mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Thanks for your contribution; we appreciate it a lot. The following instructions will help keep your pull request healthy and make it easier to get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from the maintainers. ## Motivation Support CAT-Seg open-vocabulary semantic segmentation (CVPR2023). ## Modification Support CAT-Seg open-vocabulary semantic segmentation (CVPR2023). - [x] Support CAT-Seg model training. - [x] CLIP model based `backbone` (R101 & Swin-B), aggregation layers based `neck`, and `decoder` head. - [x] Provide customized coco-stuff164k_384x384 training configs. - [x] Language model support for `open vocabulary` (OV) tasks. - [x] Support CLIP-based pretrained language model (LM) inference. - [x] Add commonly used prompt templates. - [x] Add README tutorials. - [x] Add zero-shot testing scripts. **Working on the following tasks.** - [x] Add unit tests. ## BC-breaking (Optional) Does the modification introduce changes that break the backward compatibility of the downstream repos? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR. ## Use cases (Optional) If this PR introduces a new feature, it is better to list some use cases here and update the documentation. ## Checklist 1. Pre-commit or other linting tools are used to fix potential lint issues. 2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness. 3. If the modification has a potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. 4. The documentation has been modified accordingly, like docstrings or example tutorials. --------- Co-authored-by: xiexinch <xiexinch@outlook.com>
764 lines
28 KiB
Python
764 lines
28 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from mmcv.cnn import build_norm_layer
|
|
from mmcv.cnn.bricks.transformer import FFN, build_dropout
|
|
from mmengine.model import BaseModule
|
|
from mmengine.utils import to_2tuple
|
|
|
|
from mmseg.registry import MODELS
|
|
from ..utils import FullAttention, LinearAttention
|
|
|
|
|
|
class AGWindowMSA(BaseModule):
    """Appearance-guided window multi-head self-attention (W-MSA).

    Queries and keys are projected from the whole input (feature channels
    concatenated with appearance guidance channels), whereas values are
    projected from the first ``embed_dims`` channels only.

    Note:
        A ``relative_position_index`` buffer is registered at construction
        time, but ``forward`` never reads it, so no relative position bias
        is added to the attention logits by this module.

    Args:
        embed_dims (int): Number of input channels.
        appearance_dims (int): Number of appearance guidance feature channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 appearance_dims,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.appearance_dims = appearance_dims
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        per_head_dims = embed_dims // num_heads
        # A falsy ``qk_scale`` (None or 0) selects the default scaling.
        self.scale = qk_scale or per_head_dims**-0.5

        # Pairwise index table built from two arithmetic sequences.
        win_h, win_w = self.window_size
        coords = self.double_step_seq(2 * win_w - 1, win_h, 1, win_w)
        pos_index = (coords + coords.T).flip(1).contiguous()
        self.register_buffer('relative_position_index', pos_index)

        # Layers are created in this order on purpose: it fixes the order in
        # which the random initialisers are drawn.
        self.qk = nn.Linear(
            embed_dims + appearance_dims, embed_dims * 2, bias=qkv_bias)
        self.v = nn.Linear(embed_dims, embed_dims, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Run window attention.

        Args:
            x (tensor): input features with shape of (num_windows*B, N, C),
                C = embed_dims + appearance_dims.
            mask (tensor | None, Optional): mask with shape of (num_windows,
                Wh*Ww, Wh*Ww), value should be between (-inf, 0].

        Returns:
            tensor: output of shape (num_windows*B, N, embed_dims).
        """
        batch_windows, num_tokens, _ = x.shape
        heads = self.num_heads
        head_dims = self.embed_dims // heads

        # (2, B, heads, N, head_dims): index 0 is the query, 1 is the key.
        qk = self.qk(x).reshape(batch_windows, num_tokens, 2, heads,
                                head_dims)
        qk = qk.permute(2, 0, 3, 1, 4)
        # Plain indexing (instead of tuple-unpacking the tensor) keeps the
        # code TorchScript friendly.
        q = qk[0]
        k = qk[1]

        # Values only see the feature channels, not the guidance channels.
        feat_only = x[:, :, :self.embed_dims]
        v = self.v(feat_only).reshape(batch_windows, num_tokens, heads,
                                      head_dims)
        v = v.permute(0, 2, 1, 3)  # (B, heads, N, head_dims)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if mask is not None:
            num_windows = mask.shape[0]
            # Expose the window axis so the per-window mask broadcasts over
            # the batch and head axes.
            attn = attn.view(batch_windows // num_windows, num_windows,
                             heads, num_tokens, num_tokens)
            attn = attn + mask[None, :, None]
            attn = attn.view(-1, heads, num_tokens, num_tokens)

        attn = self.attn_drop(self.softmax(attn))

        out = (attn @ v).transpose(1, 2)
        out = out.reshape(batch_windows, num_tokens, self.embed_dims)
        return self.proj_drop(self.proj(out))

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        """Return all pairwise sums of two arithmetic sequences.

        The result has shape ``(1, len1 * len2)``.
        """
        outer = torch.arange(0, step1 * len1, step1).unsqueeze(1)
        inner = torch.arange(0, step2 * len2, step2).unsqueeze(0)
        return (outer + inner).reshape(1, -1)
|
|
|
|
|
|
class AGShiftWindowMSA(BaseModule):
    """Appearance Guidance Shifted Window Multihead Self-Attention Module.

    The input is padded to a multiple of the window size, optionally rolled
    by ``shift_size``, split into windows, processed by
    :class:`AGWindowMSA` and finally stitched back to the input resolution.

    Args:
        embed_dims (int): Number of input channels.
        appearance_dims (int): Number of appearance guidance channels
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window.
        shift_size (int, optional): The shift step of each window towards
            right-bottom. If zero, act as regular window-msa. Defaults to 0.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Defaults: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Defaults: 0.
        proj_drop_rate (float, optional): Dropout ratio of output.
            Defaults: 0.
        dropout_layer (dict, optional): The dropout_layer used before output.
            Defaults: dict(type='DropPath', drop_prob=0.).
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 appearance_dims,
                 num_heads,
                 window_size,
                 shift_size=0,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0,
                 proj_drop_rate=0,
                 dropout_layer=dict(type='DropPath', drop_prob=0.),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.window_size = window_size
        self.shift_size = shift_size
        assert 0 <= self.shift_size < self.window_size

        self.w_msa = AGWindowMSA(
            embed_dims=embed_dims,
            appearance_dims=appearance_dims,
            num_heads=num_heads,
            window_size=to_2tuple(window_size),
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=proj_drop_rate,
            init_cfg=None)

        self.drop = build_dropout(dropout_layer)

    def _shift_attn_mask(self, padded_h, padded_w, device):
        """Build the additive attention mask used when windows are shifted.

        Tokens that end up in the same window after the cyclic shift but
        come from different image regions get ``-100`` added to their
        attention logit; all other pairs get ``0``.
        """
        ws, shift = self.window_size, self.shift_size
        region_map = torch.zeros((1, padded_h, padded_w, 1), device=device)
        bands = (slice(0, -ws), slice(-ws, -shift), slice(-shift, None))
        # Label the 3 x 3 regions with 0..8 in row-major order.
        region_pairs = ((rows, cols) for rows in bands for cols in bands)
        for region_id, (rows, cols) in enumerate(region_pairs):
            region_map[:, rows, cols, :] = region_id

        # (nW, window_size * window_size)
        region_windows = self.window_partition(region_map).view(
            -1, ws * ws)
        diff = region_windows.unsqueeze(1) - region_windows.unsqueeze(2)
        masked = diff.masked_fill(diff != 0, float(-100.0))
        return masked.masked_fill(diff == 0, float(0.0))

    def forward(self, query, hw_shape):
        """Apply (shifted) window attention to a token sequence.

        Args:
            query: The input query, unpacked as (B, L, C).
            hw_shape: The shape of the feature height and width; their
                product must equal L.

        Returns:
            Tensor of shape (B, H * W, ``self.w_msa.embed_dims``).
        """
        B, L, C = query.shape
        H, W = hw_shape
        assert L == H * W, 'input feature has wrong size'

        ws = self.window_size
        shift = self.shift_size
        out_dims = self.w_msa.embed_dims
        use_shift = shift > 0

        # Pad bottom/right so both spatial sizes are multiples of ``ws``.
        query = query.view(B, H, W, C)
        pad_r = (ws - W % ws) % ws
        pad_b = (ws - H % ws) % ws
        query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
        H_pad, W_pad = query.shape[1], query.shape[2]

        if use_shift:
            # Cyclic shift towards the top-left.
            shifted_query = torch.roll(
                query, shifts=(-shift, -shift), dims=(1, 2))
            attn_mask = self._shift_attn_mask(H_pad, W_pad, query.device)
        else:
            shifted_query = query
            attn_mask = None

        # (nW*B, ws, ws, C) -> (nW*B, ws*ws, C)
        query_windows = self.window_partition(shifted_query)
        query_windows = query_windows.view(-1, ws**2, C)

        # W-MSA / SW-MSA; the channel count drops to ``out_dims``.
        attn_windows = self.w_msa(query_windows, mask=attn_mask)

        # Stitch the windows back: (B, H_pad, W_pad, out_dims).
        attn_windows = attn_windows.view(-1, ws, ws, out_dims)
        merged = self.window_reverse(attn_windows, H_pad, W_pad)

        if use_shift:
            # Undo the cyclic shift.
            merged = torch.roll(merged, shifts=(shift, shift), dims=(1, 2))

        if pad_r > 0 or pad_b > 0:
            # Drop the padded rows/columns again.
            merged = merged[:, :H, :W, :].contiguous()

        merged = merged.view(B, H * W, out_dims)
        return self.drop(merged)

    def window_reverse(self, windows, H, W):
        """Merge windows back into a feature map.

        Args:
            windows: (num_windows*B, window_size, window_size, C)
            H (int): Height of image
            W (int): Width of image
        Returns:
            x: (B, H, W, C)
        """
        ws = self.window_size
        windows_per_image = H * W / ws / ws
        batch = int(windows.shape[0] / windows_per_image)
        grid = windows.view(batch, H // ws, W // ws, ws, ws, -1)
        # Interleave window rows/cols with in-window rows/cols.
        grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
        return grid.view(batch, H, W, -1)

    def window_partition(self, x):
        """Split a feature map into non-overlapping windows.

        Args:
            x: (B, H, W, C)
        Returns:
            windows: (num_windows*B, window_size, window_size, C)
        """
        B, H, W, C = x.shape
        ws = self.window_size
        tiles = x.view(B, H // ws, ws, W // ws, ws, C)
        tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
        return tiles.view(-1, ws, ws, C)
|
|
|
|
|
|
class AGSwinBlock(BaseModule):
    """Appearance Guidance Swin Transformer Block.

    Pre-norm block: LayerNorm -> (shifted) window attention with optional
    appearance guidance -> residual, followed by LayerNorm -> FFN with a
    residual connection.

    Args:
        embed_dims (int): The feature dimension.
        appearance_dims (int): The appearance guidance dimension.
        num_heads (int): Parallel attention heads.
        mlp_ratios (int): The hidden dimension ratio w.r.t. embed_dims
            for FFNs.
        window_size (int, optional): The local window scale.
            Default: 7.
        shift (bool, optional): whether to shift window or not.
            Default False.
        qkv_bias (bool, optional): enable bias for qkv if True.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        drop_rate (float, optional): Dropout rate. Default: 0.
        attn_drop_rate (float, optional): Attention dropout rate.
            Default: 0.
        drop_path_rate (float, optional): Stochastic depth rate.
            Default: 0.
        act_cfg (dict, optional): The config dict of activation function.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): The config dict of normalization.
            Default: dict(type='LN').
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 appearance_dims,
                 num_heads,
                 mlp_ratios=4,
                 window_size=7,
                 shift=False,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = AGShiftWindowMSA(
            embed_dims=embed_dims,
            appearance_dims=appearance_dims,
            num_heads=num_heads,
            window_size=window_size,
            # A shifted block moves the windows by half a window.
            shift_size=window_size // 2 if shift else 0,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            init_cfg=None)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=embed_dims * mlp_ratios,
            num_fcs=2,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=True,
            init_cfg=None)

    def forward(self, inputs, hw_shape):
        """Run the block.

        Args:
            inputs (tuple | list): ``(x, appearance_guidance)`` in that
                order. ``x`` has shape (B, L, C); ``appearance_guidance``
                is either ``None`` or a tensor laid out as (B, H, W, C_g)
                so that it can be concatenated to ``x`` on the last axis.
            hw_shape (tuple[int]): shape of feature, ``(H, W)`` with
                ``H * W == L``.

        Returns:
            Tensor: output of shape (B, L, C).
        """
        x, appearance_guidance = inputs
        B, L, C = x.shape
        H, W = hw_shape
        assert L == H * W, 'input feature has wrong size'

        identity = x
        x = self.norm1(x)

        if appearance_guidance is not None:
            # Concatenate the guidance channel-wise on the spatial layout,
            # then flatten back to a token sequence.
            x = torch.cat([x.view(B, H, W, C), appearance_guidance],
                          dim=-1).flatten(1, 2)
        # Without guidance ``x`` is left as (B, L, C), which is the layout
        # ``self.attn`` unpacks; reshaping it to 4-D here (as was done
        # before) made the attention call fail.

        x = self.attn(x, hw_shape)

        x = x + identity

        identity = x
        x = self.norm2(x)
        x = self.ffn(x, identity=identity)

        return x
|
|
|
|
|
|
@MODELS.register_module()
class SpatialAggregateLayer(BaseModule):
    """Spatial aggregation layer of CAT-Seg.

    Applies a regular and a shifted :class:`AGSwinBlock` to every
    ``(H, W)`` slice of the input, optionally guided by appearance
    features.

    Args:
        embed_dims (int): The feature dimension.
        appearance_dims (int): The appearance guidance dimension.
        num_heads (int): Parallel attention heads.
        mlp_ratios (int): The hidden dimension ratio w.r.t. embed_dims
            for FFNs.
        window_size (int, optional): The local window scale. Default: 7.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 appearance_dims,
                 num_heads,
                 mlp_ratios,
                 window_size=7,
                 qk_scale=None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Stored because ``forward`` checks it when no guidance is given.
        self.appearance_dims = appearance_dims
        self.block_1 = AGSwinBlock(
            embed_dims,
            appearance_dims,
            num_heads,
            mlp_ratios,
            window_size=window_size,
            shift=False,
            qk_scale=qk_scale)
        self.block_2 = AGSwinBlock(
            embed_dims,
            appearance_dims,
            num_heads,
            mlp_ratios,
            window_size=window_size,
            shift=True,
            qk_scale=qk_scale)
        self.guidance_norm = nn.LayerNorm(
            appearance_dims) if appearance_dims > 0 else None

    def forward(self, x, appearance_guidance):
        """Aggregate spatially.

        Args:
            x (torch.Tensor): B C T H W.
            appearance_guidance (torch.Tensor | None): B C H W, or ``None``
                when the layer was built with ``appearance_dims == 0``.

        Returns:
            torch.Tensor: B C T H W.
        """
        B, C, T, H, W = x.shape
        # Batch-major flattening: row ``b * T + t`` holds slice t of image b.
        x = x.permute(0, 2, 3, 4, 1).flatten(0, 1).flatten(1, 2)  # BT, HW, C
        if appearance_guidance is not None:
            assert self.guidance_norm is not None, (
                'appearance_guidance was given but the layer was built with '
                'appearance_dims == 0')
            # LayerNorm acts on the channel axis only, so normalising before
            # the repeat gives the same values with T times less work.
            appearance_guidance = self.guidance_norm(
                appearance_guidance.permute(0, 2, 3, 1))  # B, H, W, C
            # Repeat each image's guidance T times *consecutively* so that
            # it lines up with the batch-major rows of ``x``. A plain
            # ``repeat(T, 1, 1, 1)`` tiles the whole batch instead and pairs
            # rows with the wrong image whenever B > 1.
            appearance_guidance = appearance_guidance.repeat_interleave(
                T, dim=0)  # BT, H, W, C
        else:
            assert self.appearance_dims == 0, (
                'appearance_guidance is required when appearance_dims > 0')
        x = self.block_1((x, appearance_guidance), (H, W))
        x = self.block_2((x, appearance_guidance), (H, W))
        x = x.transpose(1, 2).reshape(B, T, C, -1)
        x = x.transpose(1, 2).reshape(B, C, T, H, W)
        return x
|
|
|
|
|
|
class AttentionLayer(nn.Module):
    """Attention layer for ClassAggregration of CAT-Seg.

    Queries and keys are computed from the input concatenated with the
    guidance (when given); values are computed from the input alone.

    Source: https://github.com/KU-CVLAB/CAT-Seg/blob/main/cat_seg/modeling/transformer/model.py#L310 # noqa
    """

    def __init__(self,
                 hidden_dim,
                 guidance_dim,
                 nheads=8,
                 attention_type='linear'):
        super().__init__()
        self.nheads = nheads
        guided_dim = hidden_dim + guidance_dim
        self.q = nn.Linear(guided_dim, hidden_dim)
        self.k = nn.Linear(guided_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)

        if attention_type not in ('linear', 'full'):
            raise NotImplementedError
        if attention_type == 'linear':
            self.attention = LinearAttention()
        else:
            self.attention = FullAttention()

    def forward(self, x, guidance=None):
        """Attend over the second axis of ``x``.

        Args:
            x: B*H_p*W_p, T, C
            guidance: B*H_p*W_p, T, C

        Returns:
            Tensor with the same leading two axes as ``x``; the heads are
            merged back into the last axis.
        """
        batch, length, _ = x.shape
        # Guidance only influences where to attend (q/k), not what is
        # aggregated (v).
        if guidance is None:
            guided = x
        else:
            guided = torch.cat([x, guidance], dim=-1)
        query = self.q(guided)
        key = self.k(guided)
        value = self.v(x)

        # Split the channel axis into heads: (B, L, nheads, C // nheads).
        head_shape = (batch, length, self.nheads, -1)
        query = query.reshape(*head_shape)
        key = key.reshape(*head_shape)
        value = value.reshape(*head_shape)

        attended = self.attention(query, key, value)
        return attended.reshape(batch, length, -1)
|
|
|
|
|
|
@MODELS.register_module()
class ClassAggregateLayer(BaseModule):
    """Class aggregation layer of CAT-Seg.

    The input is average-pooled spatially, attention is run across the
    class axis ``T`` at every pooled location, and the result is upsampled
    and added back to the input.

    Args:
        hidden_dims (int): The feature dimension.
        guidance_dims (int): The appearance guidance dimension.
        num_heads (int): Parallel attention heads.
        attention_type (str): Type of attention layer. Default: 'linear'.
        pooling_size (tuple[int] | list[int]): Pooling size.
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(
        self,
        hidden_dims=64,
        guidance_dims=64,
        num_heads=8,
        attention_type='linear',
        pooling_size=(4, 4),
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        self.pool = nn.AvgPool2d(pooling_size)
        self.attention = AttentionLayer(
            hidden_dims,
            guidance_dims,
            nheads=num_heads,
            attention_type=attention_type)
        self.MLP = FFN(
            embed_dims=hidden_dims,
            feedforward_channels=hidden_dims * 4,
            num_fcs=2)
        self.norm1 = nn.LayerNorm(hidden_dims)
        self.norm2 = nn.LayerNorm(hidden_dims)

    def pool_features(self, x):
        """Intermediate pooling layer for computational efficiency.

        Args:
            x: B, C, T, H, W

        Returns:
            Tensor of shape B, C, T, H_p, W_p.
        """
        B, C, T, H, W = x.shape
        # Fold T into the batch axis so the 2D pooling sees (N, C, H, W).
        x = x.transpose(1, 2).reshape(-1, C, H, W)
        x = self.pool(x)
        *_, H_, W_ = x.shape
        x = x.reshape(B, T, C, H_, W_).transpose(1, 2)
        return x

    def forward(self, x, guidance):
        """Aggregate across classes.

        Args:
            x: B, C, T, H, W
            guidance: B, T, C (or ``None``)

        Returns:
            Tensor of shape B, C, T, H, W.
        """
        B, C, T, H, W = x.size()
        x_pool = self.pool_features(x)
        *_, H_pool, W_pool = x_pool.size()

        # Batch-major flattening: row ``b * H_p * W_p + i`` belongs to
        # image b.
        x_pool = x_pool.permute(0, 3, 4, 2, 1).reshape(-1, T, C)
        # B*H_p*W_p T C
        if guidance is not None:
            # Repeat each image's guidance consecutively so it lines up with
            # the batch-major rows of ``x_pool``. A plain
            # ``repeat(H_p * W_p, 1, 1)`` tiles the whole batch and pairs
            # rows with another image's guidance whenever B > 1.
            guidance = guidance.repeat_interleave(H_pool * W_pool, dim=0)

        x_pool = x_pool + self.attention(self.norm1(x_pool),
                                         guidance)  # Attention
        # NOTE(review): mmcv's FFN is believed to add its own identity
        # shortcut by default, which would be on top of the residual below
        # -- confirm that this is intended.
        x_pool = x_pool + self.MLP(self.norm2(x_pool))  # MLP

        x_pool = x_pool.reshape(B, H_pool * W_pool, T,
                                C).permute(0, 2, 3, 1).reshape(
                                    B, T, C, H_pool,
                                    W_pool).flatten(0, 1)  # BT C H_p W_p
        x_pool = F.interpolate(
            x_pool, size=(H, W), mode='bilinear', align_corners=True)
        x_pool = x_pool.reshape(B, T, C, H, W).transpose(1, 2)  # B C T H W
        x = x + x_pool  # Residual

        return x
|
|
|
|
|
|
@MODELS.register_module()
class AggregatorLayer(BaseModule):
    """Single Aggregator Layer of CAT-Seg.

    Chains a :class:`SpatialAggregateLayer` (guided by appearance features)
    and a :class:`ClassAggregateLayer` (guided by text features).

    Args:
        embed_dims (int): Feature dimension of both sub-layers. Default: 64.
        text_guidance_dims (int): Guidance dimension handed to the class
            aggregation layer. Default: 512.
        appearance_guidance_dims (int): Guidance dimension handed to the
            spatial aggregation layer. Default: 512.
        num_heads (int): Attention heads of both sub-layers. Default: 4.
        mlp_ratios (int): FFN ratio of the spatial layer. Default: 4.
        window_size (int): Window size of the spatial layer. Default: 7.
        attention_type (str): Attention type of the class layer.
            Default: 'linear'.
        pooling_size (tuple | list): Pooling size of the class layer.
            Default: (2, 2).
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 embed_dims=64,
                 text_guidance_dims=512,
                 appearance_guidance_dims=512,
                 num_heads=4,
                 mlp_ratios=4,
                 window_size=7,
                 attention_type='linear',
                 pooling_size=(2, 2),
                 init_cfg=None) -> None:
        super().__init__(init_cfg=init_cfg)
        spatial_kwargs = dict(
            num_heads=num_heads,
            mlp_ratios=mlp_ratios,
            window_size=window_size)
        class_kwargs = dict(
            num_heads=num_heads,
            attention_type=attention_type,
            pooling_size=pooling_size)
        # Spatial layer first, class layer second (same creation order as
        # the order in which they are applied).
        self.spatial_agg = SpatialAggregateLayer(
            embed_dims, appearance_guidance_dims, **spatial_kwargs)
        self.class_agg = ClassAggregateLayer(
            embed_dims, text_guidance_dims, **class_kwargs)

    def forward(self, x, appearance_guidance, text_guidance):
        """Run spatial aggregation followed by class aggregation.

        Args:
            x: B C T H W
            appearance_guidance: passed to the spatial aggregation layer.
            text_guidance: passed to the class aggregation layer.
        """
        spatially_aggregated = self.spatial_agg(x, appearance_guidance)
        return self.class_agg(spatially_aggregated, text_guidance)
|
|
|
|
|
|
@MODELS.register_module()
class CATSegAggregator(BaseModule):
    """CATSeg Aggregator.

    This Aggregator is the mmseg implementation of
    `CAT-Seg <https://arxiv.org/abs/2303.11797>`_.

    It builds an image/text correlation volume, embeds it with a
    convolution and refines it with a stack of :class:`AggregatorLayer`.

    Args:
        text_guidance_dim (int): Text guidance dimensions. Default: 512.
        text_guidance_proj_dim (int): Text guidance projection dimensions.
            Default: 128.
        appearance_guidance_dim (int): Appearance guidance dimensions.
            Default: 512.
        appearance_guidance_proj_dim (int): Appearance guidance projection
            dimensions. Default: 128.
        num_layers (int): Aggregator layer number. Default: 4.
        num_heads (int): Attention layer head number. Default: 4.
        embed_dims (int): Input feature dimensions. Default: 128.
        pooling_size (tuple | list): Pooling size of the class aggregator
            layer. Default: (6, 6).
        mlp_ratios (int): The hidden dimension ratio w.r.t. input dimension.
            Default: 4.
        window_size (int): Swin block window size. Default:12.
        attention_type (str): Attention type of class aggregator layer.
            Default:'linear'.
        prompt_channel (int): Prompt channels. Default: 80.
    """

    def __init__(self,
                 text_guidance_dim=512,
                 text_guidance_proj_dim=128,
                 appearance_guidance_dim=512,
                 appearance_guidance_proj_dim=128,
                 num_layers=4,
                 num_heads=4,
                 embed_dims=128,
                 pooling_size=(6, 6),
                 mlp_ratios=4,
                 window_size=12,
                 attention_type='linear',
                 prompt_channel=80,
                 **kwargs):
        super().__init__(**kwargs)
        self.num_layers = num_layers
        self.embed_dims = embed_dims

        # Identically configured aggregator layers, applied in sequence.
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                AggregatorLayer(
                    embed_dims=embed_dims,
                    text_guidance_dims=text_guidance_proj_dim,
                    appearance_guidance_dims=appearance_guidance_proj_dim,
                    num_heads=num_heads,
                    mlp_ratios=mlp_ratios,
                    window_size=window_size,
                    attention_type=attention_type,
                    pooling_size=pooling_size))

        # Embeds the prompt axis of the correlation volume.
        self.conv1 = nn.Conv2d(
            prompt_channel, embed_dims, kernel_size=7, stride=1, padding=3)

        # Both projections are optional and disabled by a non-positive
        # input dimension.
        if appearance_guidance_dim > 0:
            self.guidance_projection = nn.Sequential(
                nn.Conv2d(
                    appearance_guidance_dim,
                    appearance_guidance_proj_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1),
                nn.ReLU(),
            )
        else:
            self.guidance_projection = None

        if text_guidance_dim > 0:
            self.text_guidance_projection = nn.Sequential(
                nn.Linear(text_guidance_dim, text_guidance_proj_dim),
                nn.ReLU(),
            )
        else:
            self.text_guidance_projection = None

    def feature_map(self, img_feats, text_feats):
        """Concatenation type cost volume.

        For ablation study of cost volume type.
        """
        num_classes = text_feats.shape[1]
        img = F.normalize(img_feats, dim=1)  # B C H W
        img = img.unsqueeze(2).repeat(1, 1, num_classes, 1, 1)  # B C T H W

        txt = F.normalize(text_feats, dim=-1)  # B T P C
        # Average over the prompt axis, then re-normalise.
        txt = F.normalize(txt.mean(dim=-2), dim=-1)  # B T C
        height, width = img.shape[-2], img.shape[-1]
        txt = txt.unsqueeze(-1).unsqueeze(-1)
        txt = txt.repeat(1, 1, 1, height, width).transpose(1, 2)
        return torch.cat((img, txt), dim=1)  # B 2C T H W

    def correlation(self, img_feats, text_feats):
        """Correlation of image features and text features."""
        img_unit = F.normalize(img_feats, dim=1)  # B C H W
        text_unit = F.normalize(text_feats, dim=-1)  # B T P C
        return torch.einsum('bchw, btpc -> bpthw', img_unit, text_unit)

    def corr_embed(self, x):
        """Correlation embeddings encoding."""
        batch = x.shape[0]
        height, width = x.shape[-2], x.shape[-1]
        # Swap axes 1 and 2 and fold the new axis 1 into the batch so the
        # 2D convolution runs on each slice independently.
        folded = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
        embedded = self.conv1(folded)
        embedded = embedded.reshape(batch, -1, self.embed_dims, height,
                                    width)
        return embedded.transpose(1, 2)

    def forward(self, inputs):
        """Aggregate the image/text correlation volume.

        Args:
            inputs (dict): including the following keys,
                'appearance_feat': list[torch.Tensor], w.r.t. out_indices of
                    `self.feature_extractor`.
                'clip_text_feat': the text feature extracted by clip text
                    encoder.
                'clip_text_feat_test': the text feature extracted by clip text
                    encoder for testing.
                'clip_img_feat': the image feature extracted clip image
                    encoder.

        Returns:
            dict: ``corr_embed`` holds the refined correlation embedding
            and ``appearance_feats`` the appearance features that were not
            consumed here (all but the first of the reversed list).
        """
        img_feats = inputs['clip_img_feat']
        batch = img_feats.size(0)
        # order (out_indices) 2, 1, 0
        appearance_guidance = inputs['appearance_feat'][::-1]
        # Different text features are used for training and testing.
        text_key = 'clip_text_feat' if self.training else \
            'clip_text_feat_test'
        text_feats = inputs[text_key].repeat(batch, 1, 1, 1)

        # ``self.feature_map`` is the concatenation-based alternative to
        # the correlation volume.
        corr_embed = self.corr_embed(self.correlation(img_feats, text_feats))

        projected_guidance = None
        if self.guidance_projection is not None:
            projected_guidance = self.guidance_projection(
                appearance_guidance[0])

        projected_text_guidance = None
        if self.text_guidance_projection is not None:
            # Average over the prompt axis and rescale to unit length.
            text_feats = text_feats.mean(dim=-2)
            text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
            projected_text_guidance = self.text_guidance_projection(
                text_feats)

        for layer in self.layers:
            corr_embed = layer(corr_embed, projected_guidance,
                               projected_text_guidance)

        return dict(
            corr_embed=corr_embed, appearance_feats=appearance_guidance[1:])
|