mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Project] support SAM inferencer (#2897)
This commit is contained in:
parent
1271db2cff
commit
20f80f9162
40
projects/sam_inference_demo/README.md
Normal file
40
projects/sam_inference_demo/README.md
Normal file
@ -0,0 +1,40 @@
|
||||
# Introducing the Segment Anything Model (SAM) Inference Demo!
|
||||
|
||||
Welcome to the Segment Anything (SA) Inference Demo, a user-friendly implementation based on the original Segment Anything project. Our demo allows you to experience the power and versatility of the Segment Anything Model (SAM) through an easy-to-use API.
|
||||
|
||||
With this inference demo, you can explore the capabilities of the Segment Anything Model and witness its effectiveness in various tasks and image distributions. For more information on the original project, dataset, and model, please visit the official website at https://segment-anything.com.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.10
|
||||
- PyTorch 1.13
|
||||
- MMEngine >= v0.7.2
|
||||
- MMCV >= v2.0.0
|
||||
|
||||
### Installation
|
||||
|
||||
We assume that you have already installed PyTorch. If not, please follow the instructions on the [PyTorch website](https://pytorch.org/).
|
||||
|
||||
**1. Install MMEngine & MMCV**
|
||||
|
||||
```shell
|
||||
pip install openmim
|
||||
mim install mmengine
|
||||
mim install 'mmcv>=2.0.0'
|
||||
```
|
||||
|
||||
**2. Install MMPretrain**
|
||||
|
||||
```shell
|
||||
pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
|
||||
```
|
||||
|
||||
**3. Install MMSegmentation**
|
||||
|
||||
```shell
|
||||
pip install mmsegmentation
|
||||
```
|
||||
|
||||
### Usage
|
||||
|
||||
Open the `sam_image_demo.ipynb` notebook and follow the instructions to run the demo.
|
2
projects/sam_inference_demo/sam/__init__.py
Normal file
2
projects/sam_inference_demo/sam/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .modeling import * # noqa
|
||||
from .utils import * # noqa
|
12
projects/sam_inference_demo/sam/modeling/__init__.py
Normal file
12
projects/sam_inference_demo/sam/modeling/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .mask_decoder import MaskDecoder
|
||||
from .prompt_encoder import PromptEncoder
|
||||
from .sam import SAM
|
||||
from .transformer import TwoWayTransformer
|
||||
|
||||
__all__ = ['SAM', 'MaskDecoder', 'PromptEncoder', 'TwoWayTransformer']
|
45
projects/sam_inference_demo/sam/modeling/common.py
Normal file
45
projects/sam_inference_demo/sam/modeling/common.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class MLPBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
mlp_dim: int,
|
||||
act: Type[nn.Module] = nn.GELU,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
||||
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
||||
self.act = act()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.lin2(self.act(self.lin1(x)))
|
||||
|
||||
|
||||
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
|
||||
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
|
||||
class LayerNorm2d(nn.Module):
|
||||
|
||||
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(num_channels))
|
||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
return x
|
196
projects/sam_inference_demo/sam/modeling/mask_decoder.py
Normal file
196
projects/sam_inference_demo/sam/modeling/mask_decoder.py
Normal file
@ -0,0 +1,196 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Borrowed from https://github.com/facebookresearch/segment-anything
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .common import LayerNorm2d
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MaskDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
transformer_dim: int,
|
||||
transformer: dict,
|
||||
num_multimask_outputs: int = 3,
|
||||
act_cfg: dict = dict(type='GELU'),
|
||||
iou_head_depth: int = 3,
|
||||
iou_head_hidden_dim: int = 256,
|
||||
) -> None:
|
||||
"""Predicts masks given an image and prompt embeddings, using a
|
||||
tranformer architecture.
|
||||
|
||||
Borrowed from https://github.com/facebookresearch/segment-anything
|
||||
|
||||
Arguments:
|
||||
transformer_dim (int): the channel dimension of the transformer
|
||||
transformer (nn.Module): the transformer used to predict masks
|
||||
num_multimask_outputs (int): the number of masks to predict
|
||||
when disambiguating masks
|
||||
activation (nn.Module): the type of activation to use when
|
||||
upscaling masks
|
||||
iou_head_depth (int): the depth of the MLP used to predict
|
||||
mask quality
|
||||
iou_head_hidden_dim (int): the hidden dimension of the MLP
|
||||
used to predict mask quality
|
||||
"""
|
||||
super().__init__()
|
||||
self.transformer_dim = transformer_dim
|
||||
self.transformer = MODELS.build(transformer)
|
||||
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
|
||||
self.iou_token = nn.Embedding(1, transformer_dim)
|
||||
self.num_mask_tokens = num_multimask_outputs + 1
|
||||
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
||||
|
||||
activation = MODELS.build(act_cfg)
|
||||
self.output_upscaling = nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
transformer_dim, transformer_dim // 4, kernel_size=2,
|
||||
stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation,
|
||||
nn.ConvTranspose2d(
|
||||
transformer_dim // 4,
|
||||
transformer_dim // 8,
|
||||
kernel_size=2,
|
||||
stride=2),
|
||||
activation,
|
||||
)
|
||||
self.output_hypernetworks_mlps = nn.ModuleList([
|
||||
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
||||
for i in range(self.num_mask_tokens)
|
||||
])
|
||||
|
||||
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim,
|
||||
self.num_mask_tokens, iou_head_depth)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: Tensor,
|
||||
image_pe: Tensor,
|
||||
sparse_prompt_embeddings: Tensor,
|
||||
dense_prompt_embeddings: Tensor,
|
||||
multimask_output: bool,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Predict masks given image and prompt embeddings.
|
||||
|
||||
Borrowed from https://github.com/facebookresearch/segment-anything
|
||||
|
||||
Arguments:
|
||||
image_embeddings (Tensor): the embeddings from the image encoder
|
||||
image_pe (Tensor): positional encoding with the shape of
|
||||
image_embeddings
|
||||
sparse_prompt_embeddings (Tensor): the embeddings of
|
||||
the points and boxes
|
||||
dense_prompt_embeddings (Tensor): the embeddings of the mask inputs
|
||||
multimask_output (bool): Whether to return multiple masks or a single
|
||||
mask.
|
||||
|
||||
Returns:
|
||||
Tensor: batched predicted masks
|
||||
Tensor: batched predictions of mask quality
|
||||
"""
|
||||
masks, iou_pred = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=image_pe,
|
||||
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings=dense_prompt_embeddings,
|
||||
)
|
||||
|
||||
# Select the correct mask or masks for output
|
||||
if multimask_output:
|
||||
mask_slice = slice(1, None)
|
||||
else:
|
||||
mask_slice = slice(0, 1)
|
||||
masks = masks[:, mask_slice, :, :]
|
||||
iou_pred = iou_pred[:, mask_slice]
|
||||
|
||||
# Prepare output
|
||||
return masks, iou_pred
|
||||
|
||||
def predict_masks(
|
||||
self,
|
||||
image_embeddings: Tensor,
|
||||
image_pe: Tensor,
|
||||
sparse_prompt_embeddings: Tensor,
|
||||
dense_prompt_embeddings: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Predicts masks.
|
||||
|
||||
See 'forward' for more details.
|
||||
"""
|
||||
# Concatenate output tokens
|
||||
output_tokens = torch.cat(
|
||||
[self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(
|
||||
sparse_prompt_embeddings.size(0), -1, -1)
|
||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
||||
|
||||
# Expand per-image data in batch direction to be per-mask
|
||||
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
||||
src = src + dense_prompt_embeddings
|
||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
||||
b, c, h, w = src.shape
|
||||
|
||||
# Run the transformer
|
||||
hs, src = self.transformer(src, pos_src, tokens)
|
||||
iou_token_out = hs[:, 0, :]
|
||||
mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]
|
||||
|
||||
# Upscale mask embeddings and predict masks using the mask tokens
|
||||
src = src.transpose(1, 2).view(b, c, h, w)
|
||||
upscaled_embedding = self.output_upscaling(src)
|
||||
hyper_in_list: List[Tensor] = []
|
||||
for i in range(self.num_mask_tokens):
|
||||
hyper_in_list.append(self.output_hypernetworks_mlps[i](
|
||||
mask_tokens_out[:, i, :]))
|
||||
hyper_in = torch.stack(hyper_in_list, dim=1)
|
||||
b, c, h, w = upscaled_embedding.shape
|
||||
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(
|
||||
b, -1, h, w)
|
||||
|
||||
# Generate mask quality predictions
|
||||
iou_pred = self.iou_prediction_head(iou_token_out)
|
||||
|
||||
return masks, iou_pred
|
||||
|
||||
|
||||
# Lightly adapted from
|
||||
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
|
||||
class MLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
sigmoid_output: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
self.layers = nn.ModuleList(
|
||||
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
||||
self.sigmoid_output = sigmoid_output
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
||||
if self.sigmoid_output:
|
||||
x = F.sigmoid(x)
|
||||
return x
|
227
projects/sam_inference_demo/sam/modeling/prompt_encoder.py
Normal file
227
projects/sam_inference_demo/sam/modeling/prompt_encoder.py
Normal file
@ -0,0 +1,227 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Borrowed from https://github.com/facebookresearch/segment-anything
|
||||
|
||||
from typing import Any, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .common import LayerNorm2d
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class PromptEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
image_embedding_size: Tuple[int, int],
|
||||
input_image_size: Tuple[int, int],
|
||||
mask_in_chans: int,
|
||||
activation: Type[nn.Module] = nn.GELU,
|
||||
) -> None:
|
||||
"""Encodes prompts for input to SAM's mask decoder.
|
||||
|
||||
Arguments:
|
||||
embed_dim (int): The prompts' embedding dimension
|
||||
image_embedding_size (tuple(int, int)): The spatial size of the
|
||||
image embedding, as (H, W).
|
||||
input_image_size (int): The padded size of the image as input
|
||||
to the image encoder, as (H, W).
|
||||
mask_in_chans (int): The number of hidden channels used for
|
||||
encoding input masks.
|
||||
activation (nn.Module): The activation to use when encoding
|
||||
input masks.
|
||||
"""
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.input_image_size = input_image_size
|
||||
self.image_embedding_size = image_embedding_size
|
||||
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
||||
|
||||
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
||||
point_embeddings = [
|
||||
nn.Embedding(1, embed_dim)
|
||||
for i in range(self.num_point_embeddings)
|
||||
]
|
||||
self.point_embeddings = nn.ModuleList(point_embeddings)
|
||||
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
||||
|
||||
self.mask_input_size = (4 * image_embedding_size[0],
|
||||
4 * image_embedding_size[1])
|
||||
self.mask_downscaling = nn.Sequential(
|
||||
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(mask_in_chans // 4),
|
||||
activation(),
|
||||
nn.Conv2d(
|
||||
mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
|
||||
LayerNorm2d(mask_in_chans),
|
||||
activation(),
|
||||
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
||||
)
|
||||
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
||||
|
||||
def get_dense_pe(self) -> torch.Tensor:
|
||||
"""Returns the positional encoding used to encode point prompts,
|
||||
applied to a dense set of points the shape of the image encoding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape
|
||||
1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
"""
|
||||
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
||||
|
||||
def _embed_points(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
pad: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Embeds point prompts."""
|
||||
points = points + 0.5 # Shift to center of pixel
|
||||
if pad:
|
||||
padding_point = torch.zeros((points.shape[0], 1, 2),
|
||||
device=points.device)
|
||||
padding_label = -torch.ones(
|
||||
(labels.shape[0], 1), device=labels.device)
|
||||
points = torch.cat([points, padding_point], dim=1)
|
||||
labels = torch.cat([labels, padding_label], dim=1)
|
||||
point_embedding = self.pe_layer.forward_with_coords(
|
||||
points, self.input_image_size)
|
||||
point_embedding[labels == -1] = 0.0
|
||||
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
||||
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
||||
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
||||
return point_embedding
|
||||
|
||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds box prompts."""
|
||||
boxes = boxes + 0.5 # Shift to center of pixel
|
||||
coords = boxes.reshape(-1, 2, 2)
|
||||
corner_embedding = self.pe_layer.forward_with_coords(
|
||||
coords, self.input_image_size)
|
||||
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
||||
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
||||
return corner_embedding
|
||||
|
||||
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds mask inputs."""
|
||||
mask_embedding = self.mask_downscaling(masks)
|
||||
return mask_embedding
|
||||
|
||||
def _get_batch_size(
|
||||
self,
|
||||
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> int:
|
||||
"""Gets the batch size of the output given the batch size of the input
|
||||
prompts."""
|
||||
if points is not None:
|
||||
return points[0].shape[0]
|
||||
elif boxes is not None:
|
||||
return boxes.shape[0]
|
||||
elif masks is not None:
|
||||
return masks.shape[0]
|
||||
else:
|
||||
return 1
|
||||
|
||||
def _get_device(self) -> torch.device:
|
||||
return self.point_embeddings[0].weight.device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
||||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Embeds different types of prompts, returning both sparse and dense
|
||||
embeddings.
|
||||
|
||||
Arguments:
|
||||
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
|
||||
and labels to embed.
|
||||
boxes (torch.Tensor or none): boxes to embed
|
||||
masks (torch.Tensor or none): masks to embed
|
||||
|
||||
Returns:
|
||||
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
||||
BxNx(embed_dim), where N is determined by the number of input points
|
||||
and boxes.
|
||||
torch.Tensor: dense embeddings for the masks, in the shape
|
||||
Bx(embed_dim)x(embed_H)x(embed_W)
|
||||
""" # noqa
|
||||
bs = self._get_batch_size(points, boxes, masks)
|
||||
sparse_embeddings = torch.empty((bs, 0, self.embed_dim),
|
||||
device=self._get_device())
|
||||
if points is not None:
|
||||
coords, labels = points
|
||||
point_embeddings = self._embed_points(
|
||||
coords, labels, pad=(boxes is None))
|
||||
sparse_embeddings = torch.cat(
|
||||
[sparse_embeddings, point_embeddings], dim=1)
|
||||
if boxes is not None:
|
||||
box_embeddings = self._embed_boxes(boxes)
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings],
|
||||
dim=1)
|
||||
|
||||
if masks is not None:
|
||||
dense_embeddings = self._embed_masks(masks)
|
||||
else:
|
||||
dense_embeddings = self.no_mask_embed.weight.reshape(
|
||||
1, -1, 1, 1).expand(bs, -1, self.image_embedding_size[0],
|
||||
self.image_embedding_size[1])
|
||||
|
||||
return sparse_embeddings, dense_embeddings
|
||||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""Positional encoding using random spatial frequencies."""
|
||||
|
||||
def __init__(self,
|
||||
num_pos_feats: int = 64,
|
||||
scale: Optional[float] = None) -> None:
|
||||
super().__init__()
|
||||
if scale is None or scale <= 0.0:
|
||||
scale = 1.0
|
||||
self.register_buffer(
|
||||
'positional_encoding_gaussian_matrix',
|
||||
scale * torch.randn((2, num_pos_feats)),
|
||||
)
|
||||
|
||||
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
"""Positionally encode points that are normalized to [0,1]."""
|
||||
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape # noqa
|
||||
coords = 2 * coords - 1
|
||||
coords = coords @ self.positional_encoding_gaussian_matrix
|
||||
coords = 2 * np.pi * coords
|
||||
# outputs d_1 x ... x d_n x C shape
|
||||
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
||||
|
||||
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Generate positional encoding for a grid of the specified size."""
|
||||
h, w = size
|
||||
device: Any = self.positional_encoding_gaussian_matrix.device
|
||||
grid = torch.ones((h, w), device=device, dtype=torch.float32)
|
||||
y_embed = grid.cumsum(dim=0) - 0.5
|
||||
x_embed = grid.cumsum(dim=1) - 0.5
|
||||
y_embed = y_embed / h
|
||||
x_embed = x_embed / w
|
||||
|
||||
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
||||
return pe.permute(2, 0, 1) # C x H x W
|
||||
|
||||
def forward_with_coords(self, coords_input: torch.Tensor,
|
||||
image_size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Positionally encode points that are not normalized to [0,1]."""
|
||||
coords = coords_input.clone()
|
||||
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
||||
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
||||
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
188
projects/sam_inference_demo/sam/modeling/sam.py
Normal file
188
projects/sam_inference_demo/sam/modeling/sam.py
Normal file
@ -0,0 +1,188 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Borrowed from https://github.com/facebookresearch/segment-anything
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .mask_decoder import MaskDecoder
|
||||
from .prompt_encoder import PromptEncoder
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SAM(nn.Module):
|
||||
mask_threshold: float = 0.0
|
||||
image_format: str = 'RGB'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_encoder_cfg: dict,
|
||||
prompt_encoder_cfg: dict,
|
||||
mask_decoder_cfg: dict,
|
||||
pixel_mean: List[float] = [123.675, 116.28, 103.53],
|
||||
pixel_std: List[float] = [58.395, 57.12, 57.375],
|
||||
) -> None:
|
||||
"""SAM predicts object masks from an image and input prompts. Borrowed
|
||||
from https://github.com/facebookresearch/segment-anything.
|
||||
|
||||
Arguments:
|
||||
image_encoder (ViTSAM): The backbone used to encode the
|
||||
image into image embeddings that allow for efficient mask
|
||||
prediction.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input
|
||||
prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
|
||||
and encoded prompts.
|
||||
pixel_mean (list(float)): Mean values for normalizing pixels in the
|
||||
input image.
|
||||
pixel_std (list(float)): Std values for normalizing pixels in the
|
||||
input image.
|
||||
"""
|
||||
super().__init__()
|
||||
self.image_encoder = MODELS.build(image_encoder_cfg)
|
||||
self.prompt_encoder: PromptEncoder = MODELS.build(prompt_encoder_cfg)
|
||||
self.mask_decoder: MaskDecoder = MODELS.build(mask_decoder_cfg)
|
||||
self.register_buffer('pixel_mean',
|
||||
torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
||||
self.register_buffer('pixel_std',
|
||||
torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
||||
|
||||
@property
|
||||
def device(self) -> Any:
|
||||
return self.pixel_mean.device
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
batched_input: List[Dict[str, Any]],
|
||||
multimask_output: bool,
|
||||
) -> List[Dict[str, torch.Tensor]]:
|
||||
"""Predicts masks end-to-end from provided images and prompts. If
|
||||
prompts are not known in advance, using SamPredictor is recommended
|
||||
over calling the model directly.
|
||||
|
||||
Borrowed from https://github.com/facebookresearch/segment-anything
|
||||
|
||||
Arguments:
|
||||
batched_input (list(dict)): A list over input images, each a
|
||||
dictionary with the following keys. A prompt key can be
|
||||
excluded if it is not present.
|
||||
'image': The image as a torch tensor in 3xHxW format,
|
||||
already transformed for input to the model.
|
||||
'original_size': (tuple(int, int)) The original size of
|
||||
the image before transformation, as (H, W).
|
||||
'point_coords': (torch.Tensor) Batched point prompts for
|
||||
this image, with shape BxNx2. Already transformed to the
|
||||
input frame of the model.
|
||||
'point_labels': (torch.Tensor) Batched labels for point prompts,
|
||||
with shape BxN.
|
||||
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
|
||||
Already transformed to the input frame of the model.
|
||||
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
|
||||
in the form Bx1xHxW.
|
||||
multimask_output (bool): Whether the model should predict multiple
|
||||
disambiguating masks, or return a single mask.
|
||||
|
||||
Returns:
|
||||
(list(dict)): A list over input images, where each element is
|
||||
as dictionary with the following keys.
|
||||
'masks': (torch.Tensor) Batched binary mask predictions,
|
||||
with shape BxCxHxW, where B is the number of input prompts,
|
||||
C is determiend by multimask_output, and (H, W) is the
|
||||
original size of the image.
|
||||
'iou_predictions': (torch.Tensor) The model's predictions
|
||||
of mask quality, in shape BxC.
|
||||
'low_res_logits': (torch.Tensor) Low resolution logits with
|
||||
shape BxCxHxW, where H=W=256. Can be passed as mask input
|
||||
to subsequent iterations of prediction.
|
||||
"""
|
||||
input_images = torch.stack(
|
||||
[self.preprocess(x['image']) for x in batched_input], dim=0)
|
||||
image_embeddings = self.image_encoder(input_images)
|
||||
|
||||
outputs = []
|
||||
for image_record, curr_embedding in zip(batched_input,
|
||||
image_embeddings):
|
||||
if 'point_coords' in image_record:
|
||||
points = (image_record['point_coords'],
|
||||
image_record['point_labels'])
|
||||
else:
|
||||
points = None
|
||||
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
||||
points=points,
|
||||
boxes=image_record.get('boxes', None),
|
||||
masks=image_record.get('mask_inputs', None),
|
||||
)
|
||||
low_res_masks, iou_predictions = self.mask_decoder(
|
||||
image_embeddings=curr_embedding.unsqueeze(0),
|
||||
image_pe=self.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
masks = self.postprocess_masks(
|
||||
low_res_masks,
|
||||
input_size=image_record['image'].shape[-2:],
|
||||
original_size=image_record['original_size'],
|
||||
)
|
||||
masks = masks > self.mask_threshold
|
||||
outputs.append({
|
||||
'masks': masks,
|
||||
'iou_predictions': iou_predictions,
|
||||
'low_res_logits': low_res_masks,
|
||||
})
|
||||
return outputs
|
||||
|
||||
def postprocess_masks(
|
||||
self,
|
||||
masks: torch.Tensor,
|
||||
input_size: Tuple[int, ...],
|
||||
original_size: Tuple[int, ...],
|
||||
) -> torch.Tensor:
|
||||
"""Remove padding and upscale masks to the original image size.
|
||||
|
||||
Borrowed from https://github.com/facebookresearch/segment-anything
|
||||
|
||||
Arguments:
|
||||
masks (torch.Tensor): Batched masks from the mask_decoder,
|
||||
in BxCxHxW format.
|
||||
input_size (tuple(int, int)): The size of the image input to the
|
||||
model, in (H, W) format. Used to remove padding.
|
||||
original_size (tuple(int, int)): The original size of the image
|
||||
before resizing for input to the model, in (H, W) format.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
|
||||
is given by original_size.
|
||||
"""
|
||||
masks = F.interpolate(
|
||||
masks,
|
||||
self.image_encoder.img_size,
|
||||
mode='bilinear',
|
||||
align_corners=False,
|
||||
)
|
||||
masks = masks[..., :input_size[0], :input_size[1]]
|
||||
masks = F.interpolate(
|
||||
masks, original_size, mode='bilinear', align_corners=False)
|
||||
return masks
|
||||
|
||||
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Normalize pixel values and pad to a square input."""
|
||||
# Normalize colors
|
||||
x = (x - self.pixel_mean) / self.pixel_std
|
||||
|
||||
# Pad
|
||||
h, w = x.shape[-2:]
|
||||
img_size = max(self.image_encoder.img_size)
|
||||
padh = img_size - h
|
||||
padw = img_size - w
|
||||
x = F.pad(x, (0, padw, 0, padh))
|
||||
return x
|
241
projects/sam_inference_demo/sam/modeling/transformer.py
Normal file
241
projects/sam_inference_demo/sam/modeling/transformer.py
Normal file
@ -0,0 +1,241 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .common import MLPBlock
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class TwoWayTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth: int,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int,
|
||||
activation: Type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""A transformer decoder that attends to an input image using queries
|
||||
whose positional embedding is supplied.
|
||||
|
||||
Args:
|
||||
depth (int): number of layers in the transformer
|
||||
embedding_dim (int): the channel dimension for the input embeddings
|
||||
num_heads (int): the number of heads for multihead attention. Must
|
||||
divide embedding_dim
|
||||
mlp_dim (int): the channel dimension internal to the MLP block
|
||||
activation (nn.Module): the activation to use in the MLP block
|
||||
"""
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.mlp_dim = mlp_dim
|
||||
self.layers = nn.ModuleList()
|
||||
|
||||
for i in range(depth):
|
||||
self.layers.append(
|
||||
TwoWayAttentionBlock(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_dim=mlp_dim,
|
||||
activation=activation,
|
||||
attention_downsample_rate=attention_downsample_rate,
|
||||
skip_first_layer_pe=(i == 0),
|
||||
))
|
||||
|
||||
self.final_attn_token_to_image = Attention(
|
||||
embedding_dim,
|
||||
num_heads,
|
||||
downsample_rate=attention_downsample_rate)
|
||||
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embedding: Tensor,
|
||||
image_pe: Tensor,
|
||||
point_embedding: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
image_embedding (torch.Tensor): image to attend to. Should be shape
|
||||
B x embedding_dim x h x w for any h and w.
|
||||
image_pe (torch.Tensor): the positional encoding to add to the image. Must
|
||||
have the same shape as image_embedding.
|
||||
point_embedding (torch.Tensor): the embedding to add to the query points.
|
||||
Must have shape B x N_points x embedding_dim for any N_points.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the processed point_embedding
|
||||
torch.Tensor: the processed image_embedding
|
||||
""" # noqa E501
|
||||
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
||||
bs, c, h, w = image_embedding.shape
|
||||
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
|
||||
image_pe = image_pe.flatten(2).permute(0, 2, 1)
|
||||
|
||||
# Prepare queries
|
||||
queries = point_embedding
|
||||
keys = image_embedding
|
||||
|
||||
# Apply transformer blocks and final layernorm
|
||||
for layer in self.layers:
|
||||
queries, keys = layer(
|
||||
queries=queries,
|
||||
keys=keys,
|
||||
query_pe=point_embedding,
|
||||
key_pe=image_pe,
|
||||
)
|
||||
|
||||
# Apply the final attenion layer from the points to the image
|
||||
q = queries + point_embedding
|
||||
k = keys + image_pe
|
||||
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm_final_attn(queries)
|
||||
|
||||
return queries, keys
|
||||
|
||||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int = 2048,
|
||||
activation: Type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
skip_first_layer_pe: bool = False,
|
||||
) -> None:
|
||||
"""A transformer block with four layers: (1) self-attention of sparse
|
||||
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
|
||||
block on sparse inputs, and (4) cross attention of dense inputs to
|
||||
sparse inputs.
|
||||
|
||||
Arguments:
|
||||
embedding_dim (int): the channel dimension of the embeddings
|
||||
num_heads (int): the number of heads in the attention layers
|
||||
mlp_dim (int): the hidden dimension of the mlp block
|
||||
activation (nn.Module): the activation of the mlp block
|
||||
skip_first_layer_pe (bool): skip the PE on the first layer
|
||||
"""
|
||||
super().__init__()
|
||||
self.self_attn = Attention(embedding_dim, num_heads)
|
||||
self.norm1 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.cross_attn_token_to_image = Attention(
|
||||
embedding_dim,
|
||||
num_heads,
|
||||
downsample_rate=attention_downsample_rate)
|
||||
self.norm2 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
|
||||
self.norm3 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.norm4 = nn.LayerNorm(embedding_dim)
|
||||
self.cross_attn_image_to_token = Attention(
|
||||
embedding_dim,
|
||||
num_heads,
|
||||
downsample_rate=attention_downsample_rate)
|
||||
|
||||
self.skip_first_layer_pe = skip_first_layer_pe
|
||||
|
||||
def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor,
|
||||
key_pe: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
# Self attention block
|
||||
if self.skip_first_layer_pe:
|
||||
queries = self.self_attn(q=queries, k=queries, v=queries)
|
||||
else:
|
||||
q = queries + query_pe
|
||||
attn_out = self.self_attn(q=q, k=q, v=queries)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm1(queries)
|
||||
|
||||
# Cross attention block, tokens attending to image embedding
|
||||
q = queries + query_pe
|
||||
k = keys + key_pe
|
||||
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm2(queries)
|
||||
|
||||
# MLP block
|
||||
mlp_out = self.mlp(queries)
|
||||
queries = queries + mlp_out
|
||||
queries = self.norm3(queries)
|
||||
|
||||
# Cross attention block, image embedding attending to tokens
|
||||
q = queries + query_pe
|
||||
k = keys + key_pe
|
||||
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
|
||||
keys = keys + attn_out
|
||||
keys = self.norm4(keys)
|
||||
|
||||
return queries, keys
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""An attention layer that allows for downscaling the size of the embedding
|
||||
after projection to queries, keys, and values."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
downsample_rate: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.internal_dim = embedding_dim // downsample_rate
|
||||
self.num_heads = num_heads
|
||||
assert self.internal_dim % num_heads == 0, 'num_heads must divide embedding_dim.' # noqa E501
|
||||
|
||||
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
||||
|
||||
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
||||
b, n, c = x.shape
|
||||
x = x.reshape(b, n, num_heads, c // num_heads)
|
||||
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
||||
|
||||
def _recombine_heads(self, x: Tensor) -> Tensor:
|
||||
b, n_heads, n_tokens, c_per_head = x.shape
|
||||
x = x.transpose(1, 2)
|
||||
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
||||
# Input projections
|
||||
q = self.q_proj(q)
|
||||
k = self.k_proj(k)
|
||||
v = self.v_proj(v)
|
||||
|
||||
# Separate into heads
|
||||
q = self._separate_heads(q, self.num_heads)
|
||||
k = self._separate_heads(k, self.num_heads)
|
||||
v = self._separate_heads(v, self.num_heads)
|
||||
|
||||
# Attention
|
||||
_, _, _, c_per_head = q.shape
|
||||
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
|
||||
attn = attn / math.sqrt(c_per_head)
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
|
||||
# Get output
|
||||
out = attn @ v
|
||||
out = self._recombine_heads(out)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
688
projects/sam_inference_demo/sam/sam_inferencer.py
Normal file
688
projects/sam_inference_demo/sam/sam_inferencer.py
Normal file
@ -0,0 +1,688 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.runner.checkpoint import load_checkpoint
|
||||
# yapf: disable
|
||||
from sam.utils import (MaskData, area_from_rle, batch_iterator,
|
||||
batched_mask_to_box, box_xyxy_to_xywh,
|
||||
build_all_layer_point_grids, calculate_stability_score,
|
||||
coco_encode_rle, generate_crop_boxes,
|
||||
is_box_near_crop_edge, mask_to_rle_pytorch,
|
||||
remove_small_regions, rle_to_mask, uncrop_boxes_xyxy,
|
||||
uncrop_masks, uncrop_points)
|
||||
from torchvision.ops.boxes import batched_nms, box_area
|
||||
|
||||
from mmseg.registry import MODELS, TRANSFORMS
|
||||
|
||||
# yapf: enable
|
||||
|
||||
model_zoo = {
|
||||
'base':
|
||||
'https://download.openmmlab.com/mmsegmentation/v0.5/sam/sam_vit-base-p16_3rdparty_sa1b-1024x1024_20230413-78a25eed.pth', # noqa
|
||||
'large':
|
||||
'https://download.openmmlab.com/mmsegmentation/v0.5/sam/sam_vit-large-p16_3rdparty_sa1b-1024x1024_20230413-940520da.pth', # noqa
|
||||
'huge':
|
||||
'https://download.openmmlab.com/mmsegmentation/v0.5/sam/sam_vit-huge-p16_3rdparty_sa1b-1024x1024_20230413-faaf96f6.pth', # noqa
|
||||
}
|
||||
|
||||
|
||||
class SAMInferencer:
|
||||
|
||||
def __init__(self, arch: str = 'base') -> None:
|
||||
assert arch in ['base', 'large', 'huge']
|
||||
self.model = self.init_model(arch)
|
||||
self.transform = TRANSFORMS.build(
|
||||
dict(
|
||||
type='ResizeLongestSide',
|
||||
target_length=max(self.model.image_encoder.img_size)))
|
||||
|
||||
def set_image(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
image_format: str = 'RGB',
|
||||
) -> None:
|
||||
"""Calculates the image embeddings for the provided image, allowing
|
||||
masks to be predicted with the 'predict' method.
|
||||
|
||||
Arguments:
|
||||
image (np.ndarray): The image for calculating masks. Expects an
|
||||
image in HWC uint8 format, with pixel values in [0, 255].
|
||||
image_format (str): The color format of the image, in ['RGB', 'BGR'].
|
||||
"""
|
||||
assert image_format in [
|
||||
'RGB',
|
||||
'BGR',
|
||||
], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
|
||||
if image_format != self.model.image_format:
|
||||
image = image[..., ::-1]
|
||||
|
||||
# Transform the image to the form expected by the model
|
||||
input_image = self.transform.apply_image(image)
|
||||
input_image_torch = torch.as_tensor(input_image, device=self.device)
|
||||
input_image_torch = input_image_torch.permute(
|
||||
2, 0, 1).contiguous()[None, :, :, :]
|
||||
|
||||
self.set_torch_image(input_image_torch, image.shape[:2])
|
||||
|
||||
@torch.no_grad()
|
||||
def set_torch_image(
|
||||
self,
|
||||
transformed_image: torch.Tensor,
|
||||
original_image_size: Tuple[int, ...],
|
||||
) -> None:
|
||||
"""Calculates the image embeddings for the provided image, allowing
|
||||
masks to be predicted with the 'predict' method. Expects the input
|
||||
image to be already transformed to the format expected by the model.
|
||||
|
||||
Arguments:
|
||||
transformed_image (torch.Tensor): The input image, with shape
|
||||
1x3xHxW, which has been transformed with ResizeLongestSide.
|
||||
original_image_size (tuple(int, int)): The size of the image
|
||||
before transformation, in (H, W) format.
|
||||
"""
|
||||
assert (len(transformed_image.shape) == 4
|
||||
and transformed_image.shape[1] == 3
|
||||
and max(*transformed_image.shape[2:]) == max(
|
||||
self.model.image_encoder.img_size)
|
||||
), 'set_torch_image input must be BCHW with long side'
|
||||
f' {self.model.image_encoder.img_size}.'
|
||||
self.reset_image()
|
||||
|
||||
self.original_size = original_image_size
|
||||
self.input_size = tuple(transformed_image.shape[-2:])
|
||||
input_image = self.model.preprocess(transformed_image)
|
||||
self.features = self.model.image_encoder(input_image)[0]
|
||||
self.is_image_set = True
|
||||
|
||||
def predict(
|
||||
self,
|
||||
point_coords: Optional[np.ndarray] = None,
|
||||
point_labels: Optional[np.ndarray] = None,
|
||||
box: Optional[np.ndarray] = None,
|
||||
mask_input: Optional[np.ndarray] = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Predict masks for the given input prompts, using the currently set
|
||||
image.
|
||||
|
||||
Arguments:
|
||||
point_coords (np.ndarray or None): A Nx2 array of point prompts to the
|
||||
model. Each point is in (X,Y) in pixels.
|
||||
point_labels (np.ndarray or None): A length N array of labels for the
|
||||
point prompts. 1 indicates a foreground point and 0 indicates a
|
||||
background point.
|
||||
box (np.ndarray or None): A length 4 array given a box prompt to the
|
||||
model, in XYXY format.
|
||||
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
||||
coming from a previous prediction iteration. Has form 1xHxW, where
|
||||
for SAM, H=W=256.
|
||||
multimask_output (bool): If true, the model will return three masks.
|
||||
For ambiguous input prompts (such as a single click), this will often
|
||||
produce better masks than a single prediction. If only a single
|
||||
mask is needed, the model's predicted quality score can be used
|
||||
to select the best mask. For non-ambiguous prompts, such as multiple
|
||||
input prompts, multimask_output=False can give better results.
|
||||
return_logits (bool): If true, returns un-thresholded masks logits
|
||||
instead of a binary mask.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The output masks in CxHxW format, where C is the
|
||||
number of masks, and (H, W) is the original image size.
|
||||
(np.ndarray): An array of length C containing the model's
|
||||
predictions for the quality of each mask.
|
||||
(np.ndarray): An array of shape CxHxW, where C is the number
|
||||
of masks and H=W=256. These low resolution logits can be passed to
|
||||
a subsequent iteration as mask input.
|
||||
""" # noqa
|
||||
if not self.is_image_set:
|
||||
raise RuntimeError(
|
||||
'An image must be set with .set_image(...) before mask'
|
||||
'prediction.')
|
||||
|
||||
# Transform input prompts
|
||||
coords_torch = None
|
||||
labels_torch = None
|
||||
box_torch = None
|
||||
mask_input_torch = None
|
||||
|
||||
if point_coords is not None:
|
||||
assert (
|
||||
point_labels is not None
|
||||
), 'point_labels must be supplied if point_coords is supplied.'
|
||||
point_coords = self.transform.apply_coords(point_coords,
|
||||
self.original_size)
|
||||
coords_torch = torch.as_tensor(
|
||||
point_coords, dtype=torch.float, device=self.device)
|
||||
labels_torch = torch.as_tensor(
|
||||
point_labels, dtype=torch.int, device=self.device)
|
||||
coords_torch, labels_torch = coords_torch[
|
||||
None, :, :], labels_torch[None, :]
|
||||
if box is not None:
|
||||
box = self.transform.apply_boxes(box, self.original_size)
|
||||
box_torch = torch.as_tensor(
|
||||
box, dtype=torch.float, device=self.device)
|
||||
box_torch = box_torch[None, :]
|
||||
if mask_input is not None:
|
||||
mask_input_torch = torch.as_tensor(
|
||||
mask_input, dtype=torch.float, device=self.device)
|
||||
mask_input_torch = mask_input_torch[None, :, :, :]
|
||||
|
||||
masks, iou_predictions, low_res_masks = self.predict_torch(
|
||||
coords_torch,
|
||||
labels_torch,
|
||||
box_torch,
|
||||
mask_input_torch,
|
||||
multimask_output,
|
||||
return_logits=return_logits,
|
||||
)
|
||||
|
||||
masks = masks[0].detach().cpu().numpy()
|
||||
iou_predictions = iou_predictions[0].detach().cpu().numpy()
|
||||
low_res_masks = low_res_masks[0].detach().cpu().numpy()
|
||||
return masks, iou_predictions, low_res_masks
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_torch(
|
||||
self,
|
||||
point_coords: Optional[torch.Tensor],
|
||||
point_labels: Optional[torch.Tensor],
|
||||
boxes: Optional[torch.Tensor] = None,
|
||||
mask_input: Optional[torch.Tensor] = None,
|
||||
multimask_output: bool = True,
|
||||
return_logits: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Predict masks for the given input prompts, using the currently set
|
||||
image. Input prompts are batched torch tensors and are expected to
|
||||
already be transformed to the input frame using ResizeLongestSide.
|
||||
|
||||
Arguments:
|
||||
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
|
||||
model. Each point is in (X,Y) in pixels.
|
||||
point_labels (torch.Tensor or None): A BxN array of labels for the
|
||||
point prompts. 1 indicates a foreground point and 0 indicates a
|
||||
background point.
|
||||
box (np.ndarray or None): A Bx4 array given a box prompt to the
|
||||
model, in XYXY format.
|
||||
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
||||
coming from a previous prediction iteration. Has form Bx1xHxW, where
|
||||
for SAM, H=W=256. Masks returned by a previous iteration of the
|
||||
predict method do not need further transformation.
|
||||
multimask_output (bool): If true, the model will return three masks.
|
||||
For ambiguous input prompts (such as a single click), this will often
|
||||
produce better masks than a single prediction. If only a single
|
||||
mask is needed, the model's predicted quality score can be used
|
||||
to select the best mask. For non-ambiguous prompts, such as multiple
|
||||
input prompts, multimask_output=False can give better results.
|
||||
return_logits (bool): If true, returns un-thresholded masks logits
|
||||
instead of a binary mask.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The output masks in BxCxHxW format, where C is the
|
||||
number of masks, and (H, W) is the original image size.
|
||||
(torch.Tensor): An array of shape BxC containing the model's
|
||||
predictions for the quality of each mask.
|
||||
(torch.Tensor): An array of shape BxCxHxW, where C is the number
|
||||
of masks and H=W=256. These low res logits can be passed to
|
||||
a subsequent iteration as mask input.
|
||||
""" # noqa
|
||||
if not self.is_image_set:
|
||||
raise RuntimeError(
|
||||
'An image must be set with .set_image(...) before mask '
|
||||
'prediction.')
|
||||
|
||||
if point_coords is not None:
|
||||
points = (point_coords, point_labels)
|
||||
else:
|
||||
points = None
|
||||
|
||||
# Embed prompts
|
||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
|
||||
points=points,
|
||||
boxes=boxes,
|
||||
masks=mask_input,
|
||||
)
|
||||
|
||||
# Predict masks
|
||||
low_res_masks, iou_predictions = self.model.mask_decoder(
|
||||
image_embeddings=self.features,
|
||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
|
||||
# Upscale the masks to the original image resolution
|
||||
masks = self.model.postprocess_masks(low_res_masks, self.input_size,
|
||||
self.original_size)
|
||||
|
||||
if not return_logits:
|
||||
masks = masks > self.model.mask_threshold
|
||||
|
||||
return masks, iou_predictions, low_res_masks
|
||||
|
||||
def get_image_embedding(self) -> torch.Tensor:
|
||||
"""Returns the image embeddings for the currently set image, with shape
|
||||
1xCxHxW, where C is the embedding dimension and (H,W) are the embedding
|
||||
spatial dimension of SAM (typically C=256, H=W=64)."""
|
||||
if not self.is_image_set:
|
||||
raise RuntimeError(
|
||||
'An image must be set with .set_image(...) to generate an '
|
||||
'embedding.')
|
||||
assert self.features is not None, 'Features must exist if an image has'
|
||||
' been set.'
|
||||
return self.features
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.model.device
|
||||
|
||||
def reset_image(self) -> None:
|
||||
"""Resets the currently set image."""
|
||||
self.is_image_set = False
|
||||
self.features = None
|
||||
self.orig_h = None
|
||||
self.orig_w = None
|
||||
self.input_h = None
|
||||
self.input_w = None
|
||||
|
||||
def init_model(self, arch: str):
|
||||
model = MODELS.build(
|
||||
dict(
|
||||
type='SAM',
|
||||
image_encoder_cfg=dict(
|
||||
type='mmpretrain.ViTSAM',
|
||||
arch=arch,
|
||||
img_size=1024,
|
||||
patch_size=16,
|
||||
out_channels=256,
|
||||
use_abs_pos=True,
|
||||
use_rel_pos=True,
|
||||
window_size=14,
|
||||
),
|
||||
prompt_encoder_cfg=dict(
|
||||
type='PromptEncoder',
|
||||
embed_dim=256,
|
||||
image_embedding_size=(64, 64),
|
||||
input_image_size=(1024, 1024),
|
||||
mask_in_chans=16,
|
||||
),
|
||||
mask_decoder_cfg=dict(
|
||||
type='MaskDecoder',
|
||||
num_multimask_outputs=3,
|
||||
transformer=dict(
|
||||
type='TwoWayTransformer',
|
||||
depth=2,
|
||||
embedding_dim=256,
|
||||
mlp_dim=2048,
|
||||
num_heads=8,
|
||||
),
|
||||
transformer_dim=256,
|
||||
iou_head_depth=3,
|
||||
iou_head_hidden_dim=256,
|
||||
)))
|
||||
load_checkpoint(model, model_zoo.get(arch), strict=True)
|
||||
if torch.cuda.is_available():
|
||||
model = model.cuda()
|
||||
return model
|
||||
|
||||
|
||||
class SamAutomaticMaskGenerator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arch: str = 'base',
|
||||
points_per_side: Optional[int] = 32,
|
||||
points_per_batch: int = 64,
|
||||
pred_iou_thresh: float = 0.88,
|
||||
stability_score_thresh: float = 0.95,
|
||||
stability_score_offset: float = 1.0,
|
||||
box_nms_thresh: float = 0.7,
|
||||
crop_n_layers: int = 0,
|
||||
crop_nms_thresh: float = 0.7,
|
||||
crop_overlap_ratio: float = 512 / 1500,
|
||||
crop_n_points_downscale_factor: int = 1,
|
||||
point_grids: Optional[List[np.ndarray]] = None,
|
||||
min_mask_region_area: int = 0,
|
||||
output_mode: str = 'binary_mask',
|
||||
) -> None:
|
||||
"""Using a SAM model, generates masks for the entire image. Generates a
|
||||
grid of point prompts over the image, then filters low quality and
|
||||
duplicate masks. The default settings are chosen for SAM with a ViT-H
|
||||
backbone.
|
||||
|
||||
Arguments:
|
||||
arch (str): The SAM model to use for mask prediction.
|
||||
points_per_side (int or None): The number of points to be sampled
|
||||
along one side of the image. The total number of points is
|
||||
points_per_side**2. If None, 'point_grids' must provide explicit
|
||||
point sampling.
|
||||
points_per_batch (int): Sets the number of points run simultaneously
|
||||
by the model. Higher numbers may be faster but use more GPU memory.
|
||||
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
||||
model's predicted mask quality.
|
||||
stability_score_thresh (float): A filtering threshold in [0,1], using
|
||||
the stability of the mask under changes to the cutoff used to binarize
|
||||
the model's mask predictions.
|
||||
stability_score_offset (float): The amount to shift the cutoff when
|
||||
calculated the stability score.
|
||||
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
||||
suppression to filter duplicate masks.
|
||||
crops_n_layers (int): If >0, mask prediction will be run again on
|
||||
crops of the image. Sets the number of layers to run, where each
|
||||
layer has 2**i_layer number of image crops.
|
||||
crops_nms_thresh (float): The box IoU cutoff used by non-maximal
|
||||
suppression to filter duplicate masks between different crops.
|
||||
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
||||
In the first crop layer, crops will overlap by this fraction of
|
||||
the image length. Later layers with more crops scale down this overlap.
|
||||
crop_n_points_downscale_factor (int): The number of points-per-side
|
||||
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
point_grids (list(np.ndarray) or None): A list over explicit grids
|
||||
of points used for sampling, normalized to [0,1]. The nth grid in the
|
||||
list is used in the nth crop layer. Exclusive with points_per_side.
|
||||
min_mask_region_area (int): If >0, postprocessing will be applied
|
||||
to remove disconnected regions and holes in masks with area smaller
|
||||
than min_mask_region_area. Requires opencv.
|
||||
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
||||
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
||||
For large resolutions, 'binary_mask' may consume large amounts of
|
||||
memory.
|
||||
""" # noqa
|
||||
|
||||
assert (points_per_side is None) != (
|
||||
point_grids is None
|
||||
), 'Exactly one of points_per_side or point_grid must be provided.'
|
||||
if points_per_side is not None:
|
||||
self.point_grids = build_all_layer_point_grids(
|
||||
points_per_side,
|
||||
crop_n_layers,
|
||||
crop_n_points_downscale_factor,
|
||||
)
|
||||
elif point_grids is not None:
|
||||
self.point_grids = point_grids
|
||||
else:
|
||||
raise ValueError(
|
||||
"Can't have both points_per_side and point_grid be None.")
|
||||
|
||||
assert output_mode in [
|
||||
'binary_mask',
|
||||
'uncompressed_rle',
|
||||
'coco_rle',
|
||||
], f'Unknown output_mode {output_mode}.'
|
||||
if output_mode == 'coco_rle':
|
||||
from pycocotools import \
|
||||
mask as mask_utils # type: ignore # noqa: F401
|
||||
|
||||
if min_mask_region_area > 0:
|
||||
import cv2 # type: ignore # noqa: F401
|
||||
|
||||
self.predictor = SAMInferencer(arch)
|
||||
self.points_per_batch = points_per_batch
|
||||
self.pred_iou_thresh = pred_iou_thresh
|
||||
self.stability_score_thresh = stability_score_thresh
|
||||
self.stability_score_offset = stability_score_offset
|
||||
self.box_nms_thresh = box_nms_thresh
|
||||
self.crop_n_layers = crop_n_layers
|
||||
self.crop_nms_thresh = crop_nms_thresh
|
||||
self.crop_overlap_ratio = crop_overlap_ratio
|
||||
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
||||
self.min_mask_region_area = min_mask_region_area
|
||||
self.output_mode = output_mode
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||
"""Generates masks for the given image.
|
||||
|
||||
Arguments:
|
||||
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
||||
|
||||
Returns:
|
||||
list(dict(str, any)): A list over records for masks. Each record is
|
||||
a dict containing the following keys:
|
||||
segmentation (dict(str, any) or np.ndarray): The mask. If
|
||||
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
||||
is a dictionary containing the RLE.
|
||||
bbox (list(float)): The box around the mask, in XYWH format.
|
||||
area (int): The area in pixels of the mask.
|
||||
predicted_iou (float): The model's own prediction of the mask's
|
||||
quality. This is filtered by the pred_iou_thresh parameter.
|
||||
point_coords (list(list(float))): The point coordinates input
|
||||
to the model to generate this mask.
|
||||
stability_score (float): A measure of the mask's quality. This
|
||||
is filtered on using the stability_score_thresh parameter.
|
||||
crop_box (list(float)): The crop of the image used to generate
|
||||
the mask, given in XYWH format.
|
||||
""" # noqa
|
||||
|
||||
# Generate masks
|
||||
mask_data = self._generate_masks(image)
|
||||
|
||||
# Filter small disconnected regions and holes in masks
|
||||
if self.min_mask_region_area > 0:
|
||||
mask_data = self.postprocess_small_regions(
|
||||
mask_data,
|
||||
self.min_mask_region_area,
|
||||
max(self.box_nms_thresh, self.crop_nms_thresh),
|
||||
)
|
||||
|
||||
# Encode masks
|
||||
if self.output_mode == 'coco_rle':
|
||||
mask_data['segmentations'] = [
|
||||
coco_encode_rle(rle) for rle in mask_data['rles']
|
||||
]
|
||||
elif self.output_mode == 'binary_mask':
|
||||
mask_data['segmentations'] = [
|
||||
rle_to_mask(rle) for rle in mask_data['rles']
|
||||
]
|
||||
else:
|
||||
mask_data['segmentations'] = mask_data['rles']
|
||||
|
||||
# Write mask records
|
||||
curr_anns = []
|
||||
for idx in range(len(mask_data['segmentations'])):
|
||||
ann = {
|
||||
'segmentation':
|
||||
mask_data['segmentations'][idx],
|
||||
'area':
|
||||
area_from_rle(mask_data['rles'][idx]),
|
||||
'bbox':
|
||||
box_xyxy_to_xywh(mask_data['boxes'][idx]).tolist(),
|
||||
'predicted_iou':
|
||||
mask_data['iou_preds'][idx].item(),
|
||||
'point_coords': [mask_data['points'][idx].tolist()],
|
||||
'stability_score':
|
||||
mask_data['stability_score'][idx].item(),
|
||||
'crop_box':
|
||||
box_xyxy_to_xywh(mask_data['crop_boxes'][idx]).tolist(),
|
||||
}
|
||||
curr_anns.append(ann)
|
||||
|
||||
return curr_anns
|
||||
|
||||
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
||||
orig_size = image.shape[:2]
|
||||
crop_boxes, layer_idxs = generate_crop_boxes(orig_size,
|
||||
self.crop_n_layers,
|
||||
self.crop_overlap_ratio)
|
||||
|
||||
# Iterate over image crops
|
||||
data = MaskData()
|
||||
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
||||
crop_data = self._process_crop(image, crop_box, layer_idx,
|
||||
orig_size)
|
||||
data.cat(crop_data)
|
||||
|
||||
# Remove duplicate masks between crops
|
||||
if len(crop_boxes) > 1:
|
||||
# Prefer masks from smaller crops
|
||||
scores = 1 / box_area(data['crop_boxes'])
|
||||
scores = scores.to(data['boxes'].device)
|
||||
keep_by_nms = batched_nms(
|
||||
data['boxes'].float(),
|
||||
scores,
|
||||
torch.zeros(len(data['boxes'])), # categories
|
||||
iou_threshold=self.crop_nms_thresh,
|
||||
)
|
||||
data.filter(keep_by_nms)
|
||||
|
||||
data.to_numpy()
|
||||
return data
|
||||
|
||||
def _process_crop(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
crop_box: List[int],
|
||||
crop_layer_idx: int,
|
||||
orig_size: Tuple[int, ...],
|
||||
) -> MaskData:
|
||||
# Crop the image and calculate embeddings
|
||||
x0, y0, x1, y1 = crop_box
|
||||
cropped_im = image[y0:y1, x0:x1, :]
|
||||
cropped_im_size = cropped_im.shape[:2]
|
||||
self.predictor.set_image(cropped_im)
|
||||
|
||||
# Get points for this crop
|
||||
points_scale = np.array(cropped_im_size)[None, ::-1]
|
||||
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
||||
|
||||
# Generate masks for this crop in batches
|
||||
data = MaskData()
|
||||
for (points, ) in batch_iterator(self.points_per_batch,
|
||||
points_for_image):
|
||||
batch_data = self._process_batch(points, cropped_im_size, crop_box,
|
||||
orig_size)
|
||||
data.cat(batch_data)
|
||||
del batch_data
|
||||
self.predictor.reset_image()
|
||||
|
||||
# Remove duplicates within this crop.
|
||||
keep_by_nms = batched_nms(
|
||||
data['boxes'].float(),
|
||||
data['iou_preds'],
|
||||
torch.zeros(len(data['boxes'])), # categories
|
||||
iou_threshold=self.box_nms_thresh,
|
||||
)
|
||||
data.filter(keep_by_nms)
|
||||
|
||||
# Return to the original image frame
|
||||
data['boxes'] = uncrop_boxes_xyxy(data['boxes'], crop_box)
|
||||
data['points'] = uncrop_points(data['points'], crop_box)
|
||||
data['crop_boxes'] = torch.tensor(
|
||||
[crop_box for _ in range(len(data['rles']))])
|
||||
|
||||
return data
|
||||
|
||||
def _process_batch(
|
||||
self,
|
||||
points: np.ndarray,
|
||||
im_size: Tuple[int, ...],
|
||||
crop_box: List[int],
|
||||
orig_size: Tuple[int, ...],
|
||||
) -> MaskData:
|
||||
orig_h, orig_w = orig_size
|
||||
|
||||
# Run model on this batch
|
||||
transformed_points = self.predictor.transform.apply_coords(
|
||||
points, im_size)
|
||||
in_points = torch.as_tensor(
|
||||
transformed_points, device=self.predictor.device)
|
||||
in_labels = torch.ones(
|
||||
in_points.shape[0], dtype=torch.int, device=in_points.device)
|
||||
masks, iou_preds, _ = self.predictor.predict_torch(
|
||||
in_points[:, None, :],
|
||||
in_labels[:, None],
|
||||
multimask_output=True,
|
||||
return_logits=True,
|
||||
)
|
||||
|
||||
# Serialize predictions and store in MaskData
|
||||
data = MaskData(
|
||||
masks=masks.flatten(0, 1),
|
||||
iou_preds=iou_preds.flatten(0, 1),
|
||||
points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
|
||||
)
|
||||
del masks
|
||||
|
||||
# Filter by predicted IoU
|
||||
if self.pred_iou_thresh > 0.0:
|
||||
keep_mask = data['iou_preds'] > self.pred_iou_thresh
|
||||
data.filter(keep_mask)
|
||||
|
||||
# Calculate stability score
|
||||
data['stability_score'] = calculate_stability_score(
|
||||
data['masks'], self.predictor.model.mask_threshold,
|
||||
self.stability_score_offset)
|
||||
if self.stability_score_thresh > 0.0:
|
||||
keep_mask = data['stability_score'] >= self.stability_score_thresh
|
||||
data.filter(keep_mask)
|
||||
|
||||
# Threshold masks and calculate boxes
|
||||
data['masks'] = data['masks'] > self.predictor.model.mask_threshold
|
||||
data['boxes'] = batched_mask_to_box(data['masks'])
|
||||
|
||||
# Filter boxes that touch crop boundaries
|
||||
keep_mask = ~is_box_near_crop_edge(data['boxes'], crop_box,
|
||||
[0, 0, orig_w, orig_h])
|
||||
if not torch.all(keep_mask):
|
||||
data.filter(keep_mask)
|
||||
|
||||
# Compress to RLE
|
||||
data['masks'] = uncrop_masks(data['masks'], crop_box, orig_h, orig_w)
|
||||
data['rles'] = mask_to_rle_pytorch(data['masks'])
|
||||
del data['masks']
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def postprocess_small_regions(mask_data: MaskData, min_area: int,
|
||||
nms_thresh: float) -> MaskData:
|
||||
"""Removes small disconnected regions and holes in masks, then reruns
|
||||
box NMS to remove any new duplicates.
|
||||
|
||||
Edits mask_data in place.
|
||||
|
||||
Requires open-cv as a dependency.
|
||||
"""
|
||||
if len(mask_data['rles']) == 0:
|
||||
return mask_data
|
||||
|
||||
# Filter small disconnected regions and holes
|
||||
new_masks = []
|
||||
scores = []
|
||||
for rle in mask_data['rles']:
|
||||
mask = rle_to_mask(rle)
|
||||
|
||||
mask, changed = remove_small_regions(mask, min_area, mode='holes')
|
||||
unchanged = not changed
|
||||
mask, changed = remove_small_regions(
|
||||
mask, min_area, mode='islands')
|
||||
unchanged = unchanged and not changed
|
||||
|
||||
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
||||
# Give score=0 to changed masks and score=1 to unchanged masks
|
||||
# so NMS will prefer ones that didn't need postprocessing
|
||||
scores.append(float(unchanged))
|
||||
|
||||
# Recalculate boxes and remove any new duplicates
|
||||
masks = torch.cat(new_masks, dim=0)
|
||||
boxes = batched_mask_to_box(masks)
|
||||
keep_by_nms = batched_nms(
|
||||
boxes.float(),
|
||||
torch.as_tensor(scores),
|
||||
torch.zeros(len(boxes)), # categories
|
||||
iou_threshold=nms_thresh,
|
||||
)
|
||||
|
||||
# Only recalculate RLEs for masks that have changed
|
||||
for i_mask in keep_by_nms:
|
||||
if scores[i_mask] == 0.0:
|
||||
mask_torch = masks[i_mask].unsqueeze(0)
|
||||
mask_data['rles'][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
||||
mask_data['boxes'][i_mask] = boxes[
|
||||
i_mask] # update res directly
|
||||
mask_data.filter(keep_by_nms)
|
||||
|
||||
return mask_data
|
2
projects/sam_inference_demo/sam/utils/__init__.py
Normal file
2
projects/sam_inference_demo/sam/utils/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .amg import * # noqa: F403 F401
|
||||
from .transforms import ResizeLongestSide # noqa: F403 F401
|
355
projects/sam_inference_demo/sam/utils/amg.py
Normal file
355
projects/sam_inference_demo/sam/utils/amg.py
Normal file
@ -0,0 +1,355 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# https://github.com/facebookresearch/segment-anything
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from typing import Any, Dict, Generator, ItemsView, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class MaskData:
|
||||
"""A structure for storing masks and their related data in batched format.
|
||||
|
||||
Implements basic filtering and concatenation.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
for v in kwargs.values():
|
||||
assert isinstance(
|
||||
v, (list, np.ndarray, torch.Tensor)
|
||||
), 'MaskData only supports list, numpy arrays, and torch tensors.'
|
||||
self._stats = dict(**kwargs)
|
||||
|
||||
def __setitem__(self, key: str, item: Any) -> None:
|
||||
assert isinstance(
|
||||
item, (list, np.ndarray, torch.Tensor)
|
||||
), 'MaskData only supports list, numpy arrays, and torch tensors.'
|
||||
self._stats[key] = item
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
del self._stats[key]
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._stats[key]
|
||||
|
||||
def items(self) -> ItemsView[str, Any]:
|
||||
return self._stats.items()
|
||||
|
||||
def filter(self, keep: torch.Tensor) -> None:
|
||||
for k, v in self._stats.items():
|
||||
if v is None:
|
||||
self._stats[k] = None
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
|
||||
elif isinstance(v, np.ndarray):
|
||||
self._stats[k] = v[keep.detach().cpu().numpy()]
|
||||
elif isinstance(v, list) and keep.dtype == torch.bool:
|
||||
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
|
||||
elif isinstance(v, list):
|
||||
self._stats[k] = [v[i] for i in keep]
|
||||
else:
|
||||
raise TypeError(
|
||||
f'MaskData key {k} has an unsupported type {type(v)}.')
|
||||
|
||||
def cat(self, new_stats: 'MaskData') -> None:
|
||||
for k, v in new_stats.items():
|
||||
if k not in self._stats or self._stats[k] is None:
|
||||
self._stats[k] = deepcopy(v)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
|
||||
elif isinstance(v, np.ndarray):
|
||||
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
|
||||
elif isinstance(v, list):
|
||||
self._stats[k] = self._stats[k] + deepcopy(v)
|
||||
else:
|
||||
raise TypeError(
|
||||
f'MaskData key {k} has an unsupported type {type(v)}.')
|
||||
|
||||
def to_numpy(self) -> None:
|
||||
for k, v in self._stats.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
self._stats[k] = v.detach().cpu().numpy()
|
||||
|
||||
|
||||
def is_box_near_crop_edge(boxes: torch.Tensor,
|
||||
crop_box: List[int],
|
||||
orig_box: List[int],
|
||||
atol: float = 20.0) -> torch.Tensor:
|
||||
"""Filter masks at the edge of a crop, but not at the edge of the original
|
||||
image."""
|
||||
crop_box_torch = torch.as_tensor(
|
||||
crop_box, dtype=torch.float, device=boxes.device)
|
||||
orig_box_torch = torch.as_tensor(
|
||||
orig_box, dtype=torch.float, device=boxes.device)
|
||||
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
|
||||
near_crop_edge = torch.isclose(
|
||||
boxes, crop_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_image_edge = torch.isclose(
|
||||
boxes, orig_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
|
||||
return torch.any(near_crop_edge, dim=1)
|
||||
|
||||
|
||||
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
|
||||
box_xywh = deepcopy(box_xyxy)
|
||||
box_xywh[2] = box_xywh[2] - box_xywh[0]
|
||||
box_xywh[3] = box_xywh[3] - box_xywh[1]
|
||||
return box_xywh
|
||||
|
||||
|
||||
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
|
||||
assert len(args) > 0 and all(
|
||||
len(a) == len(args[0]) for a in
|
||||
args), 'Batched iteration must have inputs of all the same size.'
|
||||
n_batches = len(args[0]) // batch_size + int(
|
||||
len(args[0]) % batch_size != 0)
|
||||
for b in range(n_batches):
|
||||
yield [arg[b * batch_size:(b + 1) * batch_size] for arg in args]
|
||||
|
||||
|
||||
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
|
||||
"""Encodes masks to an uncompressed RLE, in the format expected by pycoco
|
||||
tools."""
|
||||
# Put in fortran order and flatten h,w
|
||||
b, h, w = tensor.shape
|
||||
tensor = tensor.permute(0, 2, 1).flatten(1)
|
||||
|
||||
# Compute change indices
|
||||
diff = tensor[:, 1:] ^ tensor[:, :-1]
|
||||
change_indices = diff.nonzero()
|
||||
|
||||
# Encode run length
|
||||
out = []
|
||||
for i in range(b):
|
||||
cur_idxs = change_indices[change_indices[:, 0] == i, 1]
|
||||
cur_idxs = torch.cat([
|
||||
torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
|
||||
cur_idxs + 1,
|
||||
torch.tensor([h * w], dtype=cur_idxs.dtype,
|
||||
device=cur_idxs.device),
|
||||
])
|
||||
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
||||
counts = [] if tensor[i, 0] == 0 else [0]
|
||||
counts.extend(btw_idxs.detach().cpu().tolist())
|
||||
out.append({'size': [h, w], 'counts': counts})
|
||||
return out
|
||||
|
||||
|
||||
def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
|
||||
"""Compute a binary mask from an uncompressed RLE."""
|
||||
h, w = rle['size']
|
||||
mask = np.empty(h * w, dtype=bool)
|
||||
idx = 0
|
||||
parity = False
|
||||
for count in rle['counts']:
|
||||
mask[idx:idx + count] = parity
|
||||
idx += count
|
||||
parity ^= True
|
||||
mask = mask.reshape(w, h)
|
||||
return mask.transpose() # Put in C order
|
||||
|
||||
|
||||
def area_from_rle(rle: Dict[str, Any]) -> int:
|
||||
return sum(rle['counts'][1::2])
|
||||
|
||||
|
||||
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float,
|
||||
threshold_offset: float) -> torch.Tensor:
|
||||
"""Computes the stability score for a batch of masks.
|
||||
|
||||
The stability score is the IoU between the binary masks obtained by
|
||||
thresholding the predicted mask logits at high and low values.
|
||||
"""
|
||||
# One mask is always contained inside the other.
|
||||
# Save memory by preventing unnecessary cast to torch.int64
|
||||
intersections = ((masks > (mask_threshold + threshold_offset)).sum(
|
||||
-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
|
||||
unions = ((masks > (mask_threshold - threshold_offset)).sum(
|
||||
-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
|
||||
return intersections / unions
|
||||
|
||||
|
||||
def build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
||||
offset = 1 / (2 * n_per_side)
|
||||
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
||||
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
|
||||
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
|
||||
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
|
||||
return points
|
||||
|
||||
|
||||
def build_all_layer_point_grids(n_per_side: int, n_layers: int,
|
||||
scale_per_layer: int) -> List[np.ndarray]:
|
||||
"""Generates point grids for all crop layers."""
|
||||
points_by_layer = []
|
||||
for i in range(n_layers + 1):
|
||||
n_points = int(n_per_side / (scale_per_layer**i))
|
||||
points_by_layer.append(build_point_grid(n_points))
|
||||
return points_by_layer
|
||||
|
||||
|
||||
def generate_crop_boxes(
|
||||
im_size: Tuple[int, ...], n_layers: int,
|
||||
overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
|
||||
"""Generates a list of crop boxes of different sizes.
|
||||
|
||||
Each layer has (2**i)**2 boxes for the ith layer.
|
||||
"""
|
||||
crop_boxes, layer_idxs = [], []
|
||||
im_h, im_w = im_size
|
||||
short_side = min(im_h, im_w)
|
||||
|
||||
# Original image
|
||||
crop_boxes.append([0, 0, im_w, im_h])
|
||||
layer_idxs.append(0)
|
||||
|
||||
def crop_len(orig_len, n_crops, overlap):
|
||||
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
|
||||
|
||||
for i_layer in range(n_layers):
|
||||
n_crops_per_side = 2**(i_layer + 1)
|
||||
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
|
||||
|
||||
crop_w = crop_len(im_w, n_crops_per_side, overlap)
|
||||
crop_h = crop_len(im_h, n_crops_per_side, overlap)
|
||||
|
||||
crop_box_x0 = [
|
||||
int((crop_w - overlap) * i) for i in range(n_crops_per_side)
|
||||
]
|
||||
crop_box_y0 = [
|
||||
int((crop_h - overlap) * i) for i in range(n_crops_per_side)
|
||||
]
|
||||
|
||||
# Crops in XYWH format
|
||||
for x0, y0 in product(crop_box_x0, crop_box_y0):
|
||||
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
|
||||
crop_boxes.append(box)
|
||||
layer_idxs.append(i_layer + 1)
|
||||
|
||||
return crop_boxes, layer_idxs
|
||||
|
||||
|
||||
def uncrop_boxes_xyxy(boxes: torch.Tensor,
|
||||
crop_box: List[int]) -> torch.Tensor:
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
|
||||
# Check if boxes has a channel dimension
|
||||
if len(boxes.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
return boxes + offset
|
||||
|
||||
|
||||
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0]], device=points.device)
|
||||
# Check if points has a channel dimension
|
||||
if len(points.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
return points + offset
|
||||
|
||||
|
||||
def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int,
|
||||
orig_w: int) -> torch.Tensor:
|
||||
x0, y0, x1, y1 = crop_box
|
||||
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
|
||||
return masks
|
||||
# Coordinate transform masks
|
||||
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
|
||||
pad = (x0, pad_x - x0, y0, pad_y - y0)
|
||||
return torch.nn.functional.pad(masks, pad, value=0)
|
||||
|
||||
|
||||
def remove_small_regions(mask: np.ndarray, area_thresh: float,
|
||||
mode: str) -> Tuple[np.ndarray, bool]:
|
||||
"""Removes small disconnected regions and holes in a mask.
|
||||
|
||||
Returns the mask and an indicator of if the mask has been modified.
|
||||
"""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in ['holes', 'islands']
|
||||
correct_holes = mode == 'holes'
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(
|
||||
working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if len(small_regions) == 0:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
||||
# If every region is below threshold, keep largest
|
||||
if len(fill_labels) == 0:
|
||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
|
||||
|
||||
|
||||
def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
|
||||
from pycocotools import mask as mask_utils # type: ignore
|
||||
|
||||
h, w = uncompressed_rle['size']
|
||||
rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
|
||||
rle['counts'] = rle['counts'].decode(
|
||||
'utf-8') # Necessary to serialize with json
|
||||
return rle
|
||||
|
||||
|
||||
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculates boxes in XYXY format around masks.
|
||||
|
||||
Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the
|
||||
output shape is C1xC2x...x4.
|
||||
"""
|
||||
# torch.max below raises an error on empty inputs, just skip in this case
|
||||
if torch.numel(masks) == 0:
|
||||
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
||||
|
||||
# Normalize shape to CxHxW
|
||||
shape = masks.shape
|
||||
h, w = shape[-2:]
|
||||
if len(shape) > 2:
|
||||
masks = masks.flatten(0, -3)
|
||||
else:
|
||||
masks = masks.unsqueeze(0)
|
||||
|
||||
# Get top and bottom edges
|
||||
in_height, _ = torch.max(masks, dim=-1)
|
||||
in_height_coords = in_height * torch.arange(
|
||||
h, device=in_height.device)[None, :]
|
||||
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
|
||||
in_height_coords = in_height_coords + h * (~in_height)
|
||||
top_edges, _ = torch.min(in_height_coords, dim=-1)
|
||||
|
||||
# Get left and right edges
|
||||
in_width, _ = torch.max(masks, dim=-2)
|
||||
in_width_coords = in_width * torch.arange(
|
||||
w, device=in_width.device)[None, :]
|
||||
right_edges, _ = torch.max(in_width_coords, dim=-1)
|
||||
in_width_coords = in_width_coords + w * (~in_width)
|
||||
left_edges, _ = torch.min(in_width_coords, dim=-1)
|
||||
|
||||
# If the mask is empty the right edge will be to the left of the left edge.
|
||||
# Replace these boxes with [0, 0, 0, 0]
|
||||
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
||||
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges],
|
||||
dim=-1)
|
||||
out = out * (~empty_filter).unsqueeze(-1)
|
||||
|
||||
# Return to original shape
|
||||
if len(shape) > 2:
|
||||
out = out.reshape(*shape[:-2], 4)
|
||||
else:
|
||||
out = out[0]
|
||||
|
||||
return out
|
110
projects/sam_inference_demo/sam/utils/transforms.py
Normal file
110
projects/sam_inference_demo/sam/utils/transforms.py
Normal file
@ -0,0 +1,110 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torchvision.transforms.functional import resize # type: ignore
|
||||
from torchvision.transforms.functional import to_pil_image
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ResizeLongestSide:
|
||||
"""Resizes images to longest side 'target_length', as well as provides
|
||||
methods for resizing coordinates and boxes.
|
||||
|
||||
Provides methods for transforming both numpy array and batched torch
|
||||
tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, target_length: int) -> None:
|
||||
self.target_length = target_length
|
||||
|
||||
def apply_image(self, image: np.ndarray) -> np.ndarray:
|
||||
"""Expects a numpy array with shape HxWxC in uint8 format."""
|
||||
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1],
|
||||
self.target_length)
|
||||
return np.array(resize(to_pil_image(image), target_size))
|
||||
|
||||
def apply_coords(self, coords: np.ndarray,
|
||||
original_size: Tuple[int, ...]) -> np.ndarray:
|
||||
"""Expects a numpy array of length 2 in the final dimension.
|
||||
|
||||
Requires the original image size in (H, W) format.
|
||||
"""
|
||||
old_h, old_w = original_size
|
||||
new_h, new_w = self.get_preprocess_shape(original_size[0],
|
||||
original_size[1],
|
||||
self.target_length)
|
||||
coords = deepcopy(coords).astype(float)
|
||||
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
||||
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
||||
return coords
|
||||
|
||||
def apply_boxes(self, boxes: np.ndarray,
|
||||
original_size: Tuple[int, ...]) -> np.ndarray:
|
||||
"""Expects a numpy array shape Bx4.
|
||||
|
||||
Requires the original image size in (H, W) format.
|
||||
"""
|
||||
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
|
||||
return boxes.reshape(-1, 4)
|
||||
|
||||
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
|
||||
"""Expects batched images with shape BxCxHxW and float format.
|
||||
|
||||
This transformation may not exactly match apply_image. apply_image is
|
||||
the transformation expected by the model.
|
||||
"""
|
||||
# Expects an image in BCHW format. May not exactly match apply_image.
|
||||
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1],
|
||||
self.target_length)
|
||||
return F.interpolate(
|
||||
image,
|
||||
target_size,
|
||||
mode='bilinear',
|
||||
align_corners=False,
|
||||
antialias=True)
|
||||
|
||||
def apply_coords_torch(self, coords: torch.Tensor,
|
||||
original_size: Tuple[int, ...]) -> torch.Tensor:
|
||||
"""Expects a torch tensor with length 2 in the last dimension.
|
||||
|
||||
Requires the original image size in (H, W) format.
|
||||
"""
|
||||
old_h, old_w = original_size
|
||||
new_h, new_w = self.get_preprocess_shape(original_size[0],
|
||||
original_size[1],
|
||||
self.target_length)
|
||||
coords = deepcopy(coords).to(torch.float)
|
||||
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
||||
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
||||
return coords
|
||||
|
||||
def apply_boxes_torch(self, boxes: torch.Tensor,
|
||||
original_size: Tuple[int, ...]) -> torch.Tensor:
|
||||
"""Expects a torch tensor with shape Bx4.
|
||||
|
||||
Requires the original image size in (H, W) format.
|
||||
"""
|
||||
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
|
||||
return boxes.reshape(-1, 4)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocess_shape(oldh: int, oldw: int,
|
||||
long_side_length: int) -> Tuple[int, int]:
|
||||
"""Compute the output size given input size and target long side
|
||||
length."""
|
||||
scale = long_side_length * 1.0 / max(oldh, oldw)
|
||||
newh, neww = oldh * scale, oldw * scale
|
||||
neww = int(neww + 0.5)
|
||||
newh = int(newh + 0.5)
|
||||
return (newh, neww)
|
198
projects/sam_inference_demo/sam_image_demo.ipynb
Normal file
198
projects/sam_inference_demo/sam_image_demo.ipynb
Normal file
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user