[Feature] Add ViT of SAM (#1476)
* add vit of sam * update * update * add ut * update ut * remove num_classes * support dynamic input * add ut * add comments * update utpull/1488/head
parent
e80418a424
commit
0826df8963
|
@ -195,8 +195,12 @@ def add_usage(metafile):
|
|||
if train_model:
|
||||
template = TRAIN_TEST_TEMPLATE
|
||||
inputs['train_config'] = train_model[0].config
|
||||
else:
|
||||
elif len(filter_models_by_task(models, task='any')) > 0:
|
||||
template = TEST_ONLY_TEMPLATE
|
||||
else:
|
||||
content.append('\n<!-- [TABS-END] -->\n')
|
||||
return '\n'.join(content)
|
||||
|
||||
test_model = filter_models_by_task(models, task='any')[0]
|
||||
inputs['test_config'] = test_model.config
|
||||
inputs['test_weights'] = test_model.weights
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# SAM
|
||||
|
||||
> [Segment Anything](https://arxiv.org/abs/2304.02643)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
We introduce the Segment Anything (SA) project: a new task, model, and dataset for image segmentation. Using our efficient model in a data collection loop, we built the largest segmentation dataset to date (by far), with over 1 billionmasks on 11M licensed and privacy respecting images. The model is designed and trained to be promptable, so it can transfer zero-shot to new image distributions and tasks. We evaluate its capabilities on numerous tasks and find that its zero-shot performance is impressive – often competitive with or even superior to prior fully supervised results. We are releasing the Segment Anything Model (SAM) and corresponding dataset (SA-1B) of 1B masks and 11M images at https://segment-anything.com to foster research into foundation models for computer vision.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/36138628/231106092-261ff035-dd3b-4a8b-b2e7-e91f195090a1.png" width="100%"/>
|
||||
</div>
|
||||
|
||||
## How to use it?
|
||||
|
||||
<!-- [TABS-BEGIN] -->
|
||||
|
||||
**Use the model**
|
||||
|
||||
```python
|
||||
import torch
|
||||
from mmpretrain import get_model
|
||||
|
||||
model = get_model('vit-base-p16_sam-pre_3rdparty_sa1b-1024px', pretrained=True)
|
||||
inputs = torch.rand(1, 3, 1024, 1024)
|
||||
out = model(inputs)
|
||||
print(type(out))
|
||||
# To extract features.
|
||||
feats = model.extract_feat(inputs)
|
||||
print(type(feats))
|
||||
```
|
||||
|
||||
<!-- [TABS-END] -->
|
||||
|
||||
## Models and results
|
||||
|
||||
### Pretrained models
|
||||
|
||||
| Model | Params (M) | Flops (G) | Config | Download |
|
||||
| :--------------------------------------------- | :--------: | :-------: | :-------------------------------------: | :----------------------------------------------------------------------------------------------: |
|
||||
| `vit-base-p16_sam-pre_3rdparty_sa1b-1024px`\* | 89.67 | 486.00 | [config](vit-base-p16_sam_headless.py) | [model](https://download.openmmlab.com/mmclassification/v1/vit_sam/vit-base-p16_sam-pre_3rdparty_sa1b-1024px_20230411-2320f9cc.pth) |
|
||||
| `vit-large-p16_sam-pre_3rdparty_sa1b-1024px`\* | 308.00 | 1494.00 | [config](vit-large-p16_sam_headless.py) | [model](https://download.openmmlab.com/mmclassification/v1/vit_sam/vit-large-p16_sam-pre_3rdparty_sa1b-1024px_20230411-595feafd.pth) |
|
||||
| `vit-huge-p16_sam-pre_3rdparty_sa1b-1024px`\* | 637.00 | 2982.00 | [config](vit-huge-p16_sam_headless.py) | [model](https://download.openmmlab.com/mmclassification/v1/vit_sam/vit-huge-p16_sam-pre_3rdparty_sa1b-1024px_20230411-3f13c653.pth) |
|
||||
|
||||
*Models with * are converted from the [official repo](https://github.com/facebookresearch/segment-anything/). The config files of these models are only for inference. We haven't reprodcue the training results.*
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{kirillov2023segany,
|
||||
title={Segment Anything},
|
||||
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
|
||||
journal={arXiv:2304.02643},
|
||||
year={2023}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,61 @@
|
|||
Collections:
|
||||
- Name: SAM
|
||||
Metadata:
|
||||
Architecture:
|
||||
- Convolution
|
||||
- Dense Connections
|
||||
- Dropout
|
||||
- GELU
|
||||
- Layer Normalization
|
||||
- Multi-Head Attention
|
||||
- Scaled Dot-Product Attention
|
||||
Paper:
|
||||
Title: 'Segment Anything'
|
||||
URL: https://arxiv.org/abs/2304.02643
|
||||
README: configs/sam/README.md
|
||||
Code:
|
||||
URL: null
|
||||
Version: null
|
||||
|
||||
Models:
|
||||
- Name: vit-base-p16_sam-pre_3rdparty_sa1b-1024px
|
||||
Metadata:
|
||||
FLOPs: 486000000000
|
||||
Parameters: 89671000
|
||||
Training Data:
|
||||
- SA-1B
|
||||
In Collection: SAM
|
||||
Results: null
|
||||
Weights: https://download.openmmlab.com/mmclassification/v1/vit_sam/vit-base-p16_sam-pre_3rdparty_sa1b-1024px_20230411-2320f9cc.pth
|
||||
Config: configs/sam/vit-base-p16_sam_headless.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
|
||||
Code: https://github.com/facebookresearch/segment-anything/
|
||||
|
||||
- Name: vit-large-p16_sam-pre_3rdparty_sa1b-1024px
|
||||
Metadata:
|
||||
FLOPs: 1494000000000
|
||||
Parameters: 308000000
|
||||
Training Data:
|
||||
- SA-1B
|
||||
In Collection: SAM
|
||||
Results: null
|
||||
Weights: https://download.openmmlab.com/mmclassification/v1/vit_sam/vit-large-p16_sam-pre_3rdparty_sa1b-1024px_20230411-595feafd.pth
|
||||
Config: configs/sam/vit-large-p16_sam_headless.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
|
||||
Code: https://github.com/facebookresearch/segment-anything/
|
||||
|
||||
- Name: vit-huge-p16_sam-pre_3rdparty_sa1b-1024px
|
||||
Metadata:
|
||||
FLOPs: 2982000000000
|
||||
Parameters: 637000000
|
||||
Training Data:
|
||||
- SA-1B
|
||||
In Collection: SAM
|
||||
Results: null
|
||||
Weights: https://download.openmmlab.com/mmclassification/v1/vit_sam/vit-huge-p16_sam-pre_3rdparty_sa1b-1024px_20230411-3f13c653.pth
|
||||
Config: configs/sam/vit-huge-p16_sam_headless.py
|
||||
Converted From:
|
||||
Weights: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
||||
Code: https://github.com/facebookresearch/segment-anything/
|
|
@ -0,0 +1,24 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ViTSAM',
|
||||
arch='base',
|
||||
img_size=1024,
|
||||
patch_size=16,
|
||||
out_channels=256,
|
||||
use_abs_pos=True,
|
||||
use_rel_pos=True,
|
||||
window_size=14,
|
||||
),
|
||||
neck=None,
|
||||
head=None,
|
||||
)
|
||||
|
||||
data_preprocessor = dict(
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
|
@ -0,0 +1,24 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ViTSAM',
|
||||
arch='huge',
|
||||
img_size=1024,
|
||||
patch_size=16,
|
||||
out_channels=256,
|
||||
use_abs_pos=True,
|
||||
use_rel_pos=True,
|
||||
window_size=14,
|
||||
),
|
||||
neck=None,
|
||||
head=None,
|
||||
)
|
||||
|
||||
data_preprocessor = dict(
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
|
@ -0,0 +1,24 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='ViTSAM',
|
||||
arch='large',
|
||||
img_size=1024,
|
||||
patch_size=16,
|
||||
out_channels=256,
|
||||
use_abs_pos=True,
|
||||
use_rel_pos=True,
|
||||
window_size=14,
|
||||
),
|
||||
neck=None,
|
||||
head=None,
|
||||
)
|
||||
|
||||
data_preprocessor = dict(
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
|
@ -187,6 +187,7 @@ Backbones
|
|||
VGG
|
||||
Vig
|
||||
VisionTransformer
|
||||
ViTSAM
|
||||
XCiT
|
||||
|
||||
.. module:: mmpretrain.models.necks
|
||||
|
|
|
@ -52,6 +52,7 @@ from .van import VAN
|
|||
from .vgg import VGG
|
||||
from .vig import PyramidVig, Vig
|
||||
from .vision_transformer import VisionTransformer
|
||||
from .vit_sam import ViTSAM
|
||||
from .xcit import XCiT
|
||||
|
||||
__all__ = [
|
||||
|
@ -116,4 +117,5 @@ __all__ = [
|
|||
'Vig',
|
||||
'PyramidVig',
|
||||
'XCiT',
|
||||
'ViTSAM',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,572 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import LayerNorm2d, build_norm_layer, resize_pos_embed, to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
def window_partition(x: torch.Tensor,
|
||||
window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
"""Partition into non-overlapping windows with padding if needed.
|
||||
|
||||
Borrowed from https://github.com/facebookresearch/segment-anything/
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tokens with [B, H, W, C].
|
||||
window_size (int): Window size.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
|
||||
|
||||
- ``windows``: Windows after partition with
|
||||
[B * num_windows, window_size, window_size, C].
|
||||
- ``(Hp, Wp)``: Padded height and width before partition
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
|
||||
pad_h = (window_size - H % window_size) % window_size
|
||||
pad_w = (window_size - W % window_size) % window_size
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
||||
Hp, Wp = H + pad_h, W + pad_w
|
||||
|
||||
x = x.view(B, Hp // window_size, window_size, Wp // window_size,
|
||||
window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4,
|
||||
5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
|
||||
def window_unpartition(windows: torch.Tensor, window_size: int,
|
||||
pad_hw: Tuple[int, int],
|
||||
hw: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Window unpartition into original sequences and removing padding.
|
||||
|
||||
Borrowed from https://github.com/facebookresearch/segment-anything/
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tokens with
|
||||
[B * num_windows, window_size, window_size, C].
|
||||
window_size (int): Window size.
|
||||
pad_hw (tuple): Padded height and width (Hp, Wp).
|
||||
hw (tuple): Original height and width (H, W) before padding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Unpartitioned sequences with [B, H, W, C].
|
||||
"""
|
||||
Hp, Wp = pad_hw
|
||||
H, W = hw
|
||||
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
||||
x = windows.view(B, Hp // window_size, Wp // window_size, window_size,
|
||||
window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
||||
|
||||
if Hp > H or Wp > W:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
return x
|
||||
|
||||
|
||||
def get_rel_pos(q_size: int, k_size: int,
|
||||
rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""Get relative positional embeddings according to the relative positions
|
||||
of query and key sizes.
|
||||
|
||||
Borrowed from https://github.com/facebookresearch/segment-anything/
|
||||
|
||||
Args:
|
||||
q_size (int): Size of query q.
|
||||
k_size (int): Size of key k.
|
||||
rel_pos (torch.Tensor): Relative position embeddings (L, C).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Extracted positional embeddings according to relative
|
||||
positions.
|
||||
"""
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate rel pos if needed.
|
||||
if rel_pos.shape[0] != max_rel_dist:
|
||||
# Interpolate rel pos.
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode='linear',
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1,
|
||||
max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
rel_pos_resized = rel_pos
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
||||
relative_coords = (q_coords -
|
||||
k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
||||
|
||||
def add_decomposed_rel_pos(
|
||||
attn: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
rel_pos_h: torch.Tensor,
|
||||
rel_pos_w: torch.Tensor,
|
||||
q_size: Tuple[int, int],
|
||||
k_size: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""Borrowed from https://github.com/facebookresearch/segment-anything/
|
||||
|
||||
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
||||
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
|
||||
|
||||
Args:
|
||||
attn (torch.Tensor): Attention map.
|
||||
q (torch.Tensor): Query q in the attention layer with shape
|
||||
(B, q_h * q_w, C).
|
||||
rel_pos_h (torch.Tensor): Relative position embeddings (Lh, C) for
|
||||
height axis.
|
||||
rel_pos_w (torch.Tensor): Relative position embeddings (Lw, C) for
|
||||
width axis.
|
||||
q_size (tuple): Spatial sequence size of query q with (q_h, q_w).
|
||||
k_size (tuple): Spatial sequence size of key k with (k_h, k_w).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Attention map with added relative positional embeddings.
|
||||
"""
|
||||
q_h, q_w = q_size
|
||||
k_h, k_w = k_size
|
||||
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
||||
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
||||
|
||||
B, _, dim = q.shape
|
||||
r_q = q.reshape(B, q_h, q_w, dim)
|
||||
rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
|
||||
rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
|
||||
|
||||
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] +
|
||||
rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""Multi-head Attention block with relative position embeddings.
|
||||
|
||||
Borrowed from https://github.com/facebookresearch/segment-anything/
|
||||
|
||||
Args:
|
||||
embed_dims (int): The embedding dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
qkv_bias (bool): If True, add a learnable bias to q, k, v.
|
||||
Defaults to True.
|
||||
use_rel_pos (bool):Whether to use relative position embedding.
|
||||
Defaults to False.
|
||||
input_size (int, optional): Input resolution for calculating the
|
||||
relative positional parameter size. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dims: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_embed_dims = embed_dims // num_heads
|
||||
self.scale = head_embed_dims**-0.5
|
||||
|
||||
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(embed_dims, embed_dims)
|
||||
|
||||
self.use_rel_pos = use_rel_pos
|
||||
if self.use_rel_pos:
|
||||
assert (input_size is not None), \
|
||||
'Input size must be provided if using relative position embed.'
|
||||
# initialize relative positional embeddings
|
||||
self.rel_pos_h = nn.Parameter(
|
||||
torch.zeros(2 * input_size[0] - 1, head_embed_dims))
|
||||
self.rel_pos_w = nn.Parameter(
|
||||
torch.zeros(2 * input_size[1] - 1, head_embed_dims))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (3, B, nHead, H * W, C)
|
||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads,
|
||||
-1).permute(2, 0, 3, 1, 4)
|
||||
# q, k, v with shape (B * nHead, H * W, C)
|
||||
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
||||
|
||||
attn = (q * self.scale) @ k.transpose(-2, -1)
|
||||
|
||||
if self.use_rel_pos:
|
||||
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h,
|
||||
self.rel_pos_w, (H, W), (H, W))
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).view(B, self.num_heads, H, W,
|
||||
-1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TransformerEncoderLayer(BaseModule):
|
||||
"""Encoder layer with window attention in Vision Transformer.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension
|
||||
num_heads (int): Parallel attention heads
|
||||
feedforward_channels (int): The hidden dimension for FFNs
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Defaults to 2.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Defaluts to ``dict(type='GELU')``.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
use_rel_pos (bool):Whether to use relative position embedding.
|
||||
Defaults to False.
|
||||
window_size (int): Window size for window attention. Defaults to 0.
|
||||
input_size (int, optional): Input resolution for calculating the
|
||||
relative positional parameter size. Defaults to None.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims: int,
|
||||
num_heads: int,
|
||||
feedforward_channels: int,
|
||||
drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
num_fcs: int = 2,
|
||||
qkv_bias: bool = True,
|
||||
act_cfg: dict = dict(type='GELU'),
|
||||
norm_cfg: dict = dict(type='LN'),
|
||||
use_rel_pos: bool = False,
|
||||
window_size: int = 0,
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
init_cfg=None):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
self.window_size = window_size
|
||||
|
||||
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
|
||||
|
||||
self.attn = Attention(
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
use_rel_pos=use_rel_pos,
|
||||
input_size=input_size if window_size == 0 else
|
||||
(window_size, window_size),
|
||||
)
|
||||
|
||||
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
|
||||
|
||||
self.ffn = FFN(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return self.ln1
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
return self.ln2
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = x
|
||||
x = self.ln1(x)
|
||||
# Window partition
|
||||
if self.window_size > 0:
|
||||
H, W = x.shape[1], x.shape[2]
|
||||
x, pad_hw = window_partition(x, self.window_size)
|
||||
|
||||
x = self.attn(x)
|
||||
# Reverse window partition
|
||||
if self.window_size > 0:
|
||||
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
||||
x = shortcut + x
|
||||
|
||||
x = self.ffn(self.ln2(x), identity=x)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class ViTSAM(BaseBackbone):
|
||||
"""Vision Transformer as image encoder used in SAM.
|
||||
|
||||
A PyTorch implement of backbone: `Segment Anything
|
||||
<https://arxiv.org/abs/2304.02643>`_
|
||||
|
||||
Args:
|
||||
arch (str | dict): Vision Transformer architecture. If use string,
|
||||
choose from 'base', 'large', 'huge'. If use dict, it should have
|
||||
below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **num_layers** (int): The number of transformer encoder layers.
|
||||
- **num_heads** (int): The number of heads in attention modules.
|
||||
- **feedforward_channels** (int): The hidden dimensions in
|
||||
feedforward modules.
|
||||
- **global_attn_indexes** (int): The index of layers with global
|
||||
attention.
|
||||
|
||||
Defaults to 'base'.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 16.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
out_channels (int): The num of output channels, if equal to 0, the
|
||||
channel reduction layer is disabled. Defaults to 256.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
qkv_bias (bool): Whether to add bias for qkv in attention modules.
|
||||
Defaults to True.
|
||||
use_abs_pos (bool): Whether to use absolute position embedding.
|
||||
Defaults to True.
|
||||
use_rel_pos (bool):Whether to use relative position embedding.
|
||||
Defaults to True.
|
||||
window_size (int): Window size for window attention. Defaults to 14.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
interpolate_mode (str): Select the interpolate mode for position
|
||||
embeding vector resize. Defaults to "bicubic".
|
||||
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
|
||||
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||
encoder. Defaults to an empty dict.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
arch_zoo = {
|
||||
**dict.fromkeys(
|
||||
['b', 'base'], {
|
||||
'embed_dims': 768,
|
||||
'num_layers': 12,
|
||||
'num_heads': 12,
|
||||
'feedforward_channels': 3072,
|
||||
'global_attn_indexes': [2, 5, 8, 11]
|
||||
}),
|
||||
**dict.fromkeys(
|
||||
['l', 'large'], {
|
||||
'embed_dims': 1024,
|
||||
'num_layers': 24,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 4096,
|
||||
'global_attn_indexes': [5, 11, 17, 23]
|
||||
}),
|
||||
**dict.fromkeys(
|
||||
['h', 'huge'], {
|
||||
'embed_dims': 1280,
|
||||
'num_layers': 32,
|
||||
'num_heads': 16,
|
||||
'feedforward_channels': 5120,
|
||||
'global_attn_indexes': [7, 15, 23, 31]
|
||||
}),
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
arch: str = 'base',
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 256,
|
||||
out_indices: int = -1,
|
||||
drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
qkv_bias: bool = True,
|
||||
use_abs_pos: bool = True,
|
||||
use_rel_pos: bool = True,
|
||||
window_size: int = 14,
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
frozen_stages: int = -1,
|
||||
interpolate_mode: str = 'bicubic',
|
||||
patch_cfg: dict = dict(),
|
||||
layer_cfgs: dict = dict(),
|
||||
init_cfg: Optional[dict] = None):
|
||||
super().__init__(init_cfg)
|
||||
|
||||
if isinstance(arch, str):
|
||||
arch = arch.lower()
|
||||
assert arch in set(self.arch_zoo), \
|
||||
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
|
||||
self.arch_settings = self.arch_zoo[arch]
|
||||
else:
|
||||
essential_keys = {
|
||||
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
|
||||
}
|
||||
assert isinstance(arch, dict) and essential_keys <= set(arch), \
|
||||
f'Custom arch needs a dict with keys {essential_keys}'
|
||||
self.arch_settings = arch
|
||||
|
||||
self.embed_dims = self.arch_settings['embed_dims']
|
||||
self.num_layers = self.arch_settings['num_layers']
|
||||
self.global_attn_indexes = self.arch_settings['global_attn_indexes']
|
||||
self.img_size = to_2tuple(img_size)
|
||||
|
||||
# Set patch embedding
|
||||
_patch_cfg = dict(
|
||||
in_channels=in_channels,
|
||||
input_size=img_size,
|
||||
embed_dims=self.embed_dims,
|
||||
conv_type='Conv2d',
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
)
|
||||
_patch_cfg.update(patch_cfg)
|
||||
self.patch_embed = PatchEmbed(**_patch_cfg)
|
||||
self.patch_resolution = self.patch_embed.init_out_size
|
||||
# num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
self.use_abs_pos = use_abs_pos
|
||||
self.interpolate_mode = interpolate_mode
|
||||
if use_abs_pos:
|
||||
# Set position embedding
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, *self.patch_resolution, self.embed_dims))
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
assert isinstance(out_indices, Sequence), \
|
||||
f'"out_indices" must by a sequence or int, ' \
|
||||
f'get {type(out_indices)} instead.'
|
||||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = self.num_layers + index
|
||||
assert 0 <= out_indices[i] <= self.num_layers, \
|
||||
f'Invalid out_indices {index}'
|
||||
self.out_indices = out_indices
|
||||
|
||||
# stochastic depth decay rule
|
||||
dpr = np.linspace(0, drop_path_rate, self.num_layers)
|
||||
|
||||
self.layers = ModuleList()
|
||||
if isinstance(layer_cfgs, dict):
|
||||
layer_cfgs = [layer_cfgs] * self.num_layers
|
||||
for i in range(self.num_layers):
|
||||
_layer_cfg = dict(
|
||||
embed_dims=self.embed_dims,
|
||||
num_heads=self.arch_settings['num_heads'],
|
||||
feedforward_channels=self.
|
||||
arch_settings['feedforward_channels'],
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
qkv_bias=qkv_bias,
|
||||
window_size=window_size
|
||||
if i not in self.global_attn_indexes else 0,
|
||||
input_size=self.patch_resolution,
|
||||
use_rel_pos=use_rel_pos,
|
||||
norm_cfg=norm_cfg)
|
||||
_layer_cfg.update(layer_cfgs[i])
|
||||
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
|
||||
|
||||
self.out_channels = out_channels
|
||||
if self.out_channels > 0:
|
||||
self.channel_reduction = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
self.embed_dims,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_channels, eps=1e-6),
|
||||
nn.Conv2d(
|
||||
out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_channels, eps=1e-6),
|
||||
)
|
||||
|
||||
# freeze stages only when self.frozen_stages > 0
|
||||
self.frozen_stages = frozen_stages
|
||||
if self.frozen_stages > 0:
|
||||
self._freeze_stages()
|
||||
|
||||
def init_weights(self):
|
||||
super().init_weights()
|
||||
|
||||
if not (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
if self.pos_embed is not None:
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
|
||||
def _freeze_stages(self):
|
||||
# freeze position embedding
|
||||
if self.pos_embed is not None:
|
||||
self.pos_embed.requires_grad = False
|
||||
# set dropout to eval model
|
||||
self.drop_after_pos.eval()
|
||||
# freeze patch embedding
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# freeze layers
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
m = self.layers[i - 1]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
|
||||
B = x.shape[0]
|
||||
x, patch_resolution = self.patch_embed(x)
|
||||
x = x.view(B, patch_resolution[0], patch_resolution[1],
|
||||
self.embed_dims)
|
||||
|
||||
if self.use_abs_pos:
|
||||
# 'resize_pos_embed' only supports 'pos_embed' with ndim==3, but
|
||||
# in ViTSAM, the 'pos_embed' has 4 dimensions (1, H, W, C), so it
|
||||
# is flattened. Besides, ViTSAM doesn't have any extra token.
|
||||
resized_pos_embed = resize_pos_embed(
|
||||
self.pos_embed.flatten(1, 2),
|
||||
self.patch_resolution,
|
||||
patch_resolution,
|
||||
mode=self.interpolate_mode,
|
||||
num_extra_tokens=0)
|
||||
x = x + resized_pos_embed.view(1, *patch_resolution,
|
||||
self.embed_dims)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
|
||||
if i in self.out_indices:
|
||||
if self.out_channels > 0:
|
||||
x = self.channel_reduction(x.permute(0, 3, 1, 2))
|
||||
outs.append(x)
|
||||
|
||||
return tuple(outs)
|
|
@ -67,3 +67,4 @@ Import:
|
|||
- configs/maskfeat/metafile.yml
|
||||
- configs/milan/metafile.yml
|
||||
- configs/riformer/metafile.yml
|
||||
- configs/sam/metafile.yml
|
||||
|
|
|
@ -16,6 +16,7 @@ class Cfg:
|
|||
build: bool = True
|
||||
forward: bool = True
|
||||
backward: bool = True
|
||||
extract_feat: bool = True
|
||||
input_shape: tuple = (1, 3, 224, 224)
|
||||
|
||||
|
||||
|
@ -25,6 +26,10 @@ test_list = [
|
|||
Cfg(name='xcit-nano-12-p8_3rdparty-dist_in1k-384px',
|
||||
backbone=mmpretrain.models.XCiT,
|
||||
input_shape=(1, 3, 384, 384)),
|
||||
Cfg(name='vit-base-p16_sam-pre_3rdparty_sa1b-1024px',
|
||||
backbone=mmpretrain.models.ViTSAM,
|
||||
forward=False,
|
||||
backward=False),
|
||||
]
|
||||
|
||||
|
||||
|
@ -52,6 +57,14 @@ def test_forward(cfg: Cfg):
|
|||
outputs = model(inputs)
|
||||
assert outputs.shape == (1, cfg.num_classes)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg', test_list)
|
||||
def test_extract_feat(cfg: Cfg):
|
||||
if not cfg.extract_feat:
|
||||
return
|
||||
|
||||
model = get_model(cfg.name)
|
||||
inputs = torch.rand(*cfg.input_shape)
|
||||
feats = model.extract_feat(inputs)
|
||||
assert isinstance(feats, tuple)
|
||||
assert len(feats) == 1
|
||||
|
|
Loading…
Reference in New Issue