701 lines
26 KiB
Python
701 lines
26 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Optional, Sequence
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from mmcv.cnn import build_activation_layer, build_norm_layer
|
|
from mmcv.cnn.bricks import DropPath
|
|
from mmcv.cnn.bricks.transformer import PatchEmbed
|
|
from mmengine.model import BaseModule, ModuleList
|
|
from mmengine.model.weight_init import trunc_normal_
|
|
from mmengine.utils import to_2tuple
|
|
|
|
from ..builder import BACKBONES
|
|
from ..utils import resize_pos_embed
|
|
from .base_backbone import BaseBackbone
|
|
|
|
|
|
def resize_decomposed_rel_pos(rel_pos, q_size, k_size):
    """Gather relative positional embeddings for every (query, key) pair
    along one spatial axis.

    The embedding table is first linearly interpolated to length
    ``2 * max(q_size, k_size) - 1`` whenever its current length differs.
    When ``q_size`` and ``k_size`` differ, the coordinates of the shorter
    axis are stretched by the size ratio before offsets are taken.

    Args:
        rel_pos (Tensor): relative position embeddings (L, C).
        q_size (int): size of query q.
        k_size (int): size of key k.

    Returns:
        Tensor: Embeddings of shape ``(q_size, k_size, C)``; entry ``[i, j]``
        is the table row selected for query position ``i`` and key
        position ``j``.
    """
    target_len = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == target_len:
        table = rel_pos
    else:
        # F.interpolate expects (N, C, L): move the length axis last,
        # resample, then restore the (L, C) layout.
        stretched = F.interpolate(
            rel_pos.transpose(0, 1).unsqueeze(0),
            size=target_len,
            mode='linear',
        )
        table = stretched.squeeze(0).transpose(0, 1)

    # Only the shorter axis is stretched; the other ratio clamps to 1.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    offsets = (torch.arange(q_size)[:, None] * q_scale
               - torch.arange(k_size)[None, :] * k_scale)
    # Shift so that the most negative offset lands on row 0 of the table.
    indices = offsets + (k_size - 1) * k_scale
    return table[indices.long()]
|
|
|
|
|
|
def add_decomposed_rel_pos(attn,
                           q,
                           q_shape,
                           k_shape,
                           rel_pos_h,
                           rel_pos_w,
                           has_cls_token=False):
    """Add decomposed spatial relative position bias to attention logits.

    The bias is the sum of a height term and a width term, each obtained by
    contracting the query with the corresponding relative position table.

    Args:
        attn (Tensor): Attention logits, 4-D; the spatial part
            ``attn[:, :, s:, s:]`` must hold ``q_h * q_w`` query rows and
            ``k_h * k_w`` key columns. Modified in place.
        q (Tensor): Query of shape ``(B, num_heads, N, C)``.
        q_shape (Tuple[int]): ``(q_h, q_w)`` spatial size of the query.
        k_shape (Tuple[int]): ``(k_h, k_w)`` spatial size of the key.
        rel_pos_h (Tensor): Relative position table for the height axis.
        rel_pos_w (Tensor): Relative position table for the width axis.
        has_cls_token (bool): If True, the first token is skipped.
            Defaults to False.

    Returns:
        Tensor: ``attn`` with the bias added (same object as the input).
    """
    start = 1 if has_cls_token else 0
    batch, heads, _, dim = q.shape
    q_h, q_w = q_shape
    k_h, k_w = k_shape

    table_h = resize_decomposed_rel_pos(rel_pos_h, q_h, k_h)
    table_w = resize_decomposed_rel_pos(rel_pos_w, q_w, k_w)

    spatial_q = q[:, :, start:].reshape(batch, heads, q_h, q_w, dim)
    bias_h = torch.einsum('byhwc,hkc->byhwk', spatial_q, table_h)
    bias_w = torch.einsum('byhwc,wkc->byhwk', spatial_q, table_w)
    # Broadcast (.., k_h, 1) + (.., 1, k_w) into the full (k_h, k_w) grid.
    bias = bias_h.unsqueeze(-1) + bias_w.unsqueeze(-2)

    spatial_attn = attn[:, :, start:, start:].view(
        batch, -1, q_h, q_w, k_h, k_w)
    spatial_attn += bias
    attn[:, :, start:, start:] = spatial_attn.view(
        batch, -1, q_h * q_w, k_h * k_w)

    return attn
|
|
|
|
|
|
class MLP(BaseModule):
    """Two-layer perceptron computing ``fc2(act(fc1(x)))``.

    Unlike :class:`mmcv.cnn.bricks.transformer.FFN`, the input and output
    widths of this module may differ.

    Args:
        in_channels (int): The number of input channels.
        hidden_channels (int, optional): The number of hidden layer channels.
            If None, same as the ``in_channels``. Defaults to None.
        out_channels (int, optional): The number of output channels. If None,
            same as the ``in_channels``. Defaults to None.
        act_cfg (dict): The config of activation function.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict, optional): The config of weight initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels=None,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Falsy widths (None or 0) fall back to the input width.
        hidden_width = hidden_channels or in_channels
        output_width = out_channels or in_channels
        self.fc1 = nn.Linear(in_channels, hidden_width)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = nn.Linear(hidden_width, output_width)

    def forward(self, x):
        """Apply ``fc1`` -> activation -> ``fc2`` to ``x``."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
|
|
|
|
|
|
def attention_pool(x: torch.Tensor,
                   pool: nn.Module,
                   in_size: tuple,
                   norm: Optional[nn.Module] = None):
    """Spatially pool a sequence of feature tokens.

    The token axis is unflattened to ``in_size``, pooled as a 2D map with
    every head folded into the batch axis, flattened back, and optionally
    normalized.

    Args:
        x (torch.Tensor): The input tensor, should be with shape
            ``(B, num_heads, L, C)`` or ``(B, L, C)``.
        pool (nn.Module): The pooling module, applied to a
            ``(B * num_heads, C, H, W)`` map.
        in_size (Tuple[int]): The ``(H, W)`` shape of the input feature map;
            ``H * W`` must equal ``L``.
        norm (nn.Module, optional): The normalization module applied to the
            pooled tokens. Defaults to None.

    Returns:
        tuple: The pooled tokens (same rank as ``x``) and the pooled
        spatial size ``(H', W')``.

    Raises:
        RuntimeError: If ``x`` is neither 3-D nor 4-D.
    """
    drop_head_axis = x.ndim == 3
    if drop_head_axis:
        heads = 1
        batch, length, channels = x.shape
    elif x.ndim == 4:
        batch, heads, length, channels = x.shape
    else:
        raise RuntimeError(f'Unsupported input dimension {x.shape}')

    height, width = in_size
    assert length == height * width

    # (B, heads, H*W, C) -> (B*heads, C, H, W) so a 2D pool can run on it.
    grid = x.reshape(batch * heads, height, width, channels)
    grid = pool(grid.permute(0, 3, 1, 2).contiguous())
    out_size = grid.shape[-2:]

    # (B*heads, C, H', W') -> (B, heads, H'*W', C)
    tokens = grid.reshape(batch, heads, channels, -1).transpose(2, 3)
    if norm is not None:
        tokens = norm(tokens)

    return (tokens.squeeze(1) if drop_head_axis else tokens), out_size
|
|
|
|
|
|
class MultiScaleAttention(BaseModule):
    """Multiscale Multi-head Attention block.

    Query, key and value are each pooled by a depthwise ``nn.Conv2d``
    followed by a normalization layer before attention. The number of output
    tokens is therefore controlled by ``stride_q`` and the number of attended
    tokens by ``stride_kv``.

    Args:
        in_dims (int): Number of input channels.
        out_dims (int): Number of output channels.
        num_heads (int): Number of attention heads.
        qkv_bias (bool): If True, add a learnable bias to query, key and
            value. Defaults to True.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='LN')``.
        pool_kernel (tuple): kernel size for qkv pooling layers.
            Defaults to (3, 3).
        stride_q (int): stride size for q pooling layer. Defaults to 1.
        stride_kv (int): stride size for kv pooling layer. Defaults to 1.
        rel_pos_spatial (bool): Whether to enable the spatial relative
            position embedding. Defaults to False.
        residual_pooling (bool): Whether to enable the residual connection
            after attention pooling. Defaults to True.
        input_size (int | Tuple[int], optional): The input resolution,
            necessary if enable the ``rel_pos_spatial``. Only square sizes
            are supported. Defaults to None.
        rel_pos_zero_init (bool): If True, zero initialize relative
            positional parameters. Defaults to False.
        init_cfg (dict, optional): The config of weight initialization.
            Defaults to None.
    """

    def __init__(self,
                 in_dims,
                 out_dims,
                 num_heads,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 pool_kernel=(3, 3),
                 stride_q=1,
                 stride_kv=1,
                 rel_pos_spatial=False,
                 residual_pooling=True,
                 input_size=None,
                 rel_pos_zero_init=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        self.in_dims = in_dims
        self.out_dims = out_dims

        head_dim = out_dims // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(in_dims, out_dims * 3, bias=qkv_bias)
        self.proj = nn.Linear(out_dims, out_dims)

        # qkv pooling: the pools run on per-head channels because
        # ``attention_pool`` folds the heads into the batch axis.
        pool_padding = [k // 2 for k in pool_kernel]
        pool_dims = out_dims // num_heads

        def build_pooling(stride):
            # Depthwise conv (groups == channels) + norm layer.
            pool = nn.Conv2d(
                pool_dims,
                pool_dims,
                pool_kernel,
                stride=stride,
                padding=pool_padding,
                groups=pool_dims,
                bias=False,
            )
            norm = build_norm_layer(norm_cfg, pool_dims)[1]
            return pool, norm

        self.pool_q, self.norm_q = build_pooling(stride_q)
        self.pool_k, self.norm_k = build_pooling(stride_kv)
        self.pool_v, self.norm_v = build_pooling(stride_kv)

        self.residual_pooling = residual_pooling

        self.rel_pos_spatial = rel_pos_spatial
        self.rel_pos_zero_init = rel_pos_zero_init
        if self.rel_pos_spatial:
            # initialize relative positional embeddings
            assert input_size is not None, \
                '`input_size` is required when `rel_pos_spatial` is True.'
            if isinstance(input_size, int):
                input_size = (input_size, input_size)
            assert input_size[0] == input_size[1], \
                f'Only square `input_size` is supported, got {input_size}.'

            size = input_size[0]
            # Table length covers every offset between pooled q and kv
            # positions at the initial resolution; other resolutions are
            # handled by interpolating the table in the forward pass.
            rel_dim = 2 * max(size // stride_q, size // stride_kv) - 1
            self.rel_pos_h = nn.Parameter(torch.zeros(rel_dim, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_dim, head_dim))

    def init_weights(self):
        """Weight initialization."""
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress rel_pos_zero_init if use pretrained model.
            return

        # The relative position tables only exist when ``rel_pos_spatial``
        # is enabled; accessing them otherwise raises AttributeError.
        if self.rel_pos_spatial and not self.rel_pos_zero_init:
            trunc_normal_(self.rel_pos_h, std=0.02)
            trunc_normal_(self.rel_pos_w, std=0.02)

    def forward(self, x, in_size):
        """Forward the MultiScaleAttention.

        Args:
            x (Tensor): Tokens of shape ``(B, H*W, in_dims)``.
            in_size (Tuple[int]): The ``(H, W)`` size of the token grid.

        Returns:
            tuple: Output tokens of shape ``(B, H'*W', out_dims)`` and the
            pooled query size ``(H', W')``.
        """
        B, N, _ = x.shape  # (B, H*W, C)

        # qkv: (B, H*W, 3, num_heads, C)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1)
        # q, k, v: (B, num_heads, H*W, C)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        q, q_shape = attention_pool(q, self.pool_q, in_size, norm=self.norm_q)
        k, k_shape = attention_pool(k, self.pool_k, in_size, norm=self.norm_k)
        v, _ = attention_pool(v, self.pool_v, in_size, norm=self.norm_v)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_spatial:
            attn = add_decomposed_rel_pos(attn, q, q_shape, k_shape,
                                          self.rel_pos_h, self.rel_pos_w)

        attn = attn.softmax(dim=-1)
        x = attn @ v

        if self.residual_pooling:
            # Residual from the pooled query (same token count as output).
            x = x + q

        # (B, num_heads, H'*W', C'//num_heads) -> (B, H'*W', C')
        x = x.transpose(1, 2).reshape(B, -1, self.out_dims)
        x = self.proj(x)

        return x, q_shape
|
|
|
|
|
|
class MultiScaleBlock(BaseModule):
    """One MViT transformer layer: pooled attention followed by an MLP.

    Both sub-layers are pre-norm residual branches with stochastic depth.
    When ``in_dims != out_dims`` a linear projection of the normalized input
    forms the residual; it is used in the attention branch when
    ``dim_mul_in_attention`` is True, otherwise in the MLP branch. When
    ``stride_q > 1`` the residual of the attention branch is max-pooled so it
    matches the pooled query resolution.

    Args:
        in_dims (int): Number of input channels.
        out_dims (int): Number of output channels.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of hidden dimensions in MLP layers.
            Defaults to 4.0.
        qkv_bias (bool): If True, add a learnable bias to query, key and
            value. Defaults to True.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='LN')``.
        act_cfg (dict): The config of activation function.
            Defaults to ``dict(type='GELU')``.
        qkv_pool_kernel (tuple): kernel size for qkv pooling layers.
            Defaults to (3, 3).
        stride_q (int): stride size for q pooling layer. Defaults to 1.
        stride_kv (int): stride size for kv pooling layer. Defaults to 1.
        rel_pos_spatial (bool): Whether to enable the spatial relative
            position embedding. Defaults to True.
        residual_pooling (bool): Whether to enable the residual connection
            after attention pooling. Defaults to True.
        dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in
            attention layers. If False, multiply it in MLP layers.
            Defaults to True.
        input_size (Tuple[int], optional): The input resolution, necessary
            if enable the ``rel_pos_spatial``. Defaults to None.
        rel_pos_zero_init (bool): If True, zero initialize relative
            positional parameters. Defaults to False.
        init_cfg (dict, optional): The config of weight initialization.
            Defaults to None.
    """

    def __init__(
        self,
        in_dims,
        out_dims,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path=0.0,
        norm_cfg=dict(type='LN'),
        act_cfg=dict(type='GELU'),
        qkv_pool_kernel=(3, 3),
        stride_q=1,
        stride_kv=1,
        rel_pos_spatial=True,
        residual_pooling=True,
        dim_mul_in_attention=True,
        input_size=None,
        rel_pos_zero_init=False,
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        self.in_dims = in_dims
        self.out_dims = out_dims
        self.dim_mul_in_attention = dim_mul_in_attention
        self.norm1 = build_norm_layer(norm_cfg, in_dims)[1]

        # Width of the tokens between the attention and the MLP branches.
        mid_dims = out_dims if dim_mul_in_attention else in_dims
        self.attn = MultiScaleAttention(
            in_dims,
            mid_dims,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            pool_kernel=qkv_pool_kernel,
            stride_q=stride_q,
            stride_kv=stride_kv,
            rel_pos_spatial=rel_pos_spatial,
            residual_pooling=residual_pooling,
            input_size=input_size,
            rel_pos_zero_init=rel_pos_zero_init)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = build_norm_layer(norm_cfg, mid_dims)[1]
        self.mlp = MLP(
            in_channels=mid_dims,
            hidden_channels=int(mid_dims * mlp_ratio),
            out_channels=out_dims,
            act_cfg=act_cfg)

        self.proj = (
            nn.Linear(in_dims, out_dims) if in_dims != out_dims else None)

        if stride_q <= 1:
            self.pool_skip = None
            self.init_out_size = input_size
        else:
            skip_kernel = stride_q + 1
            self.pool_skip = nn.MaxPool2d(
                skip_kernel,
                stride_q,
                int(skip_kernel // 2),
                ceil_mode=False)
            if input_size is None:
                self.init_out_size = None
            else:
                self.init_out_size = [
                    side // stride_q for side in to_2tuple(input_size)
                ]

    def forward(self, x, in_size):
        """Run the attention and MLP residual branches.

        Returns:
            tuple: The output tokens and the spatial size reported by the
            attention layer (the pooled query size).
        """
        normed = self.norm1(x)
        attn_out, out_size = self.attn(normed, in_size)

        project_here = self.proj is not None and self.dim_mul_in_attention
        residual = self.proj(normed) if project_here else x
        if self.pool_skip is not None:
            residual, _ = attention_pool(residual, self.pool_skip, in_size)
        x = residual + self.drop_path(attn_out)

        normed = self.norm2(x)
        mlp_out = self.mlp(normed)

        project_here = self.proj is not None and not self.dim_mul_in_attention
        residual = self.proj(normed) if project_here else x
        return residual + self.drop_path(mlp_out), out_size
|
|
|
|
|
|
@BACKBONES.register_module()
class MViT(BaseBackbone):
    """Multi-scale ViT v2.

    A PyTorch implement of : `MViTv2: Improved Multiscale Vision Transformers
    for Classification and Detection <https://arxiv.org/abs/2112.01526>`_

    Inspiration from `the official implementation
    <https://github.com/facebookresearch/mvit>`_ and `the detectron2
    implementation <https://github.com/facebookresearch/detectron2>`_

    Args:
        arch (str | dict): MViT architecture. If use string, choose
            from 'tiny', 'small', 'base' and 'large'. If use dict, it should
            have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of layers.
            - **num_heads** (int): The number of heads in attention
              modules of the initial layer.
            - **downscale_indices** (List[int]): The layer indices to downscale
              the feature map.

            Defaults to 'base'.
        img_size (int): The expected input image shape. Defaults to 224.
        in_channels (int): The num of input channels. Defaults to 3.
        out_scales (int | Sequence[int]): The output scale indices.
            Negative indices count from the last scale. They should not
            exceed the length of ``downscale_indices``.
            Defaults to -1, which means the last scale.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to False.
        interpolate_mode (str): Select the interpolate mode for absolute
            position embedding vector resize. Defaults to "bicubic".
        pool_kernel (tuple): kernel size for qkv pooling layers.
            Defaults to (3, 3).
        dim_mul (int): The magnification for ``embed_dims`` in the downscale
            layers. Defaults to 2.
        head_mul (int): The magnification for ``num_heads`` in the downscale
            layers. Defaults to 2.
        adaptive_kv_stride (int): The stride size for kv pooling in the initial
            layer. Defaults to 4.
        rel_pos_spatial (bool): Whether to enable the spatial relative position
            embedding. Defaults to True.
        residual_pooling (bool): Whether to enable the residual connection
            after attention pooling. Defaults to True.
        dim_mul_in_attention (bool): Whether to multiply the ``embed_dims`` in
            attention layers. If False, multiply it in MLP layers.
            Defaults to True.
        rel_pos_zero_init (bool): If True, zero initialize relative
            positional parameters. Defaults to False.
        mlp_ratio (float): Ratio of hidden dimensions in MLP layers.
            Defaults to 4.0.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN', eps=1e-6)``.
        patch_cfg (dict): Config dict for the patch embedding layer.
            Defaults to ``dict(kernel_size=7, stride=4, padding=3)``.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> import torch
        >>> from mmpretrain.models import build_backbone
        >>>
        >>> cfg = dict(type='MViT', arch='tiny', out_scales=[0, 1, 2, 3])
        >>> model = build_backbone(cfg)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> for i, output in enumerate(outputs):
        >>>     print(f'scale{i}: {output.shape}')
        scale0: torch.Size([1, 96, 56, 56])
        scale1: torch.Size([1, 192, 28, 28])
        scale2: torch.Size([1, 384, 14, 14])
        scale3: torch.Size([1, 768, 7, 7])
    """
    arch_zoo = {
        'tiny': {
            'embed_dims': 96,
            'num_layers': 10,
            'num_heads': 1,
            'downscale_indices': [1, 3, 8]
        },
        'small': {
            'embed_dims': 96,
            'num_layers': 16,
            'num_heads': 1,
            'downscale_indices': [1, 3, 14]
        },
        'base': {
            'embed_dims': 96,
            'num_layers': 24,
            'num_heads': 1,
            'downscale_indices': [2, 5, 21]
        },
        'large': {
            'embed_dims': 144,
            'num_layers': 48,
            'num_heads': 2,
            'downscale_indices': [2, 8, 44]
        },
    }
    num_extra_tokens = 0

    def __init__(self,
                 arch='base',
                 img_size=224,
                 in_channels=3,
                 out_scales=-1,
                 drop_path_rate=0.,
                 use_abs_pos_embed=False,
                 interpolate_mode='bicubic',
                 pool_kernel=(3, 3),
                 dim_mul=2,
                 head_mul=2,
                 adaptive_kv_stride=4,
                 rel_pos_spatial=True,
                 residual_pooling=True,
                 dim_mul_in_attention=True,
                 rel_pos_zero_init=False,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 patch_cfg=dict(kernel_size=7, stride=4, padding=3),
                 init_cfg=None):
        super().__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'downscale_indices'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.num_heads = self.arch_settings['num_heads']
        self.downscale_indices = self.arch_settings['downscale_indices']
        self.num_scales = len(self.downscale_indices) + 1
        # Maps the index of the last block of each scale to that scale.
        self.stage_indices = {
            index - 1: i
            for i, index in enumerate(self.downscale_indices)
        }
        self.stage_indices[self.num_layers - 1] = self.num_scales - 1
        self.use_abs_pos_embed = use_abs_pos_embed
        self.interpolate_mode = interpolate_mode

        if isinstance(out_scales, int):
            out_scales = [out_scales]
        assert isinstance(out_scales, Sequence), \
            f'"out_scales" must by a sequence or int, ' \
            f'get {type(out_scales)} instead.'
        # Build a fresh list: never write back into the caller's sequence
        # (it may be a shared config list or an immutable tuple).
        normalized_scales = []
        for index in out_scales:
            scale = self.num_scales + index if index < 0 else index
            # Valid scales are 0 .. num_scales - 1; anything else would be
            # accepted silently and never produce an output.
            assert 0 <= scale < self.num_scales, \
                f'Invalid out_scales {index}'
            normalized_scales.append(scale)
        self.out_scales = sorted(normalized_scales)

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        # Set absolute position embedding
        if self.use_abs_pos_embed:
            num_patches = self.patch_resolution[0] * self.patch_resolution[1]
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.blocks = ModuleList()
        out_dims_list = [self.embed_dims]
        num_heads = self.num_heads
        stride_kv = adaptive_kv_stride
        input_size = self.patch_resolution
        for i in range(self.num_layers):
            if i in self.downscale_indices:
                num_heads *= head_mul
                stride_q = 2
                stride_kv = max(stride_kv // 2, 1)
            else:
                stride_q = 1

            # Set output embed_dims
            if dim_mul_in_attention and i in self.downscale_indices:
                # multiply embed_dims in downscale layers.
                out_dims = out_dims_list[-1] * dim_mul
            elif not dim_mul_in_attention and i + 1 in self.downscale_indices:
                # multiply embed_dims before downscale layers.
                out_dims = out_dims_list[-1] * dim_mul
            else:
                out_dims = out_dims_list[-1]

            attention_block = MultiScaleBlock(
                in_dims=out_dims_list[-1],
                out_dims=out_dims,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_cfg=norm_cfg,
                qkv_pool_kernel=pool_kernel,
                stride_q=stride_q,
                stride_kv=stride_kv,
                rel_pos_spatial=rel_pos_spatial,
                residual_pooling=residual_pooling,
                dim_mul_in_attention=dim_mul_in_attention,
                input_size=input_size,
                rel_pos_zero_init=rel_pos_zero_init)
            self.blocks.append(attention_block)

            input_size = attention_block.init_out_size
            out_dims_list.append(out_dims)

            if i in self.stage_indices:
                stage_index = self.stage_indices[i]
                if stage_index in self.out_scales:
                    norm_layer = build_norm_layer(norm_cfg, out_dims)[1]
                    self.add_module(f'norm{stage_index}', norm_layer)

    def init_weights(self):
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        if self.use_abs_pos_embed:
            trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Forward the MViT.

        Returns:
            tuple[Tensor]: One ``(B, C, H, W)`` feature map per requested
            scale, in ascending scale order.
        """
        x, patch_resolution = self.patch_embed(x)

        if self.use_abs_pos_embed:
            x = x + resize_pos_embed(
                self.pos_embed,
                self.patch_resolution,
                patch_resolution,
                mode=self.interpolate_mode,
                num_extra_tokens=self.num_extra_tokens)

        outs = []
        for i, block in enumerate(self.blocks):
            x, patch_resolution = block(x, patch_resolution)

            if i in self.stage_indices:
                stage_index = self.stage_indices[i]
                if stage_index in self.out_scales:
                    B, _, C = x.shape
                    # Normalize only the emitted feature; the trunk keeps the
                    # raw tokens so deeper stages do not change depending on
                    # which earlier scales were requested.
                    out = getattr(self, f'norm{stage_index}')(x)
                    out = out.transpose(1, 2).reshape(B, C, *patch_resolution)
                    outs.append(out.contiguous())

        return tuple(outs)
|