# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.models.backbones.hivit import BlockWithRPE
from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import build_2d_sincos_position_embedding


class PatchSplit(nn.Module):
    """The up-sample module used in neck (transformer pyramid network).

    Args:
        dim (int): the input dimension (channel number).
        fpn_dim (int): the fpn dimension (channel number).
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
    """

    def __init__(self, dim, fpn_dim, norm_cfg):
        super().__init__()
        _, self.norm = build_norm_layer(norm_cfg, dim)
        self.reduction = nn.Linear(dim, fpn_dim * 4, bias=False)
        self.fpn_dim = fpn_dim

    def forward(self, x):
        B, N, H, W, C = x.shape
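        # Shape sketch (illustrative, e.g. dim=512, fpn_dim=256): the input
        # (B, N, H, W, 512) is normalized, projected to (B, N, H, W, 4 * 256)
        # and then pixel-shuffled into (B, N, 2 * H, 2 * W, 256), i.e. every
        # token is split into a 2x2 neighborhood with reduced channels.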
        x = self.norm(x)
        x = self.reduction(x)
        x = x.reshape(B, N, H, W, 2, 2,
                      self.fpn_dim).permute(0, 1, 2, 4, 3, 5,
                                            6).reshape(B, N, 2 * H, 2 * W,
                                                       self.fpn_dim)
        return x


@MODELS.register_module()
class iTPNPretrainDecoder(BaseModule):
    """The neck module of iTPN (transformer pyramid network).

    Args:
        num_patches (int): The number of total patches. Defaults to 196.
        patch_size (int): Image patch size. Defaults to 16.
        in_chans (int): The channel of input image. Defaults to 3.
        embed_dim (int): Encoder's embedding dimension. Defaults to 512.
        fpn_dim (int): The fpn dimension (channel number). Defaults to 256.
        fpn_depth (int): The layer number of feature pyramid. Defaults to 2.
        decoder_embed_dim (int): Decoder's embedding dimension.
            Defaults to 512.
        decoder_depth (int): The depth of decoder. Defaults to 6.
        decoder_num_heads (int): Number of attention heads of decoder.
            Defaults to 16.
        mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
            Defaults to 4.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        reconstruction_type (str): The iTPN neck supports two kinds of
            supervision: ``'pixel'`` and ``'clip'``. Defaults to 'pixel'.
        num_outs (int): The output number of neck (transformer pyramid
            network). Defaults to 3.
        qkv_bias (bool): Whether to add a learnable bias to query, key and
            value in the pyramid blocks. Defaults to True.
        qk_scale (float, optional): Override the default qk scale of
            ``head_dim ** -0.5`` in the pyramid blocks if set.
            Defaults to None.
        drop_rate (float): Dropout rate of the pyramid blocks. Defaults to 0.
        attn_drop_rate (float): Attention dropout rate of the pyramid blocks.
            Defaults to 0.
        predict_feature_dim (int, optional): The output dimension used for
            supervision. Defaults to None, in which case
            ``patch_size ** 2 * in_chans`` is used for pixel reconstruction.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 num_patches: int = 196,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 512,
                 fpn_dim: int = 256,
                 fpn_depth: int = 2,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 6,
                 decoder_num_heads: int = 16,
                 mlp_ratio: int = 4,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 reconstruction_type: str = 'pixel',
                 num_outs: int = 3,
                 qkv_bias: bool = True,
                 qk_scale: Optional[float] = None,
                 drop_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 predict_feature_dim: Optional[float] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_patches = num_patches
        assert reconstruction_type in ['pixel', 'clip'], \
            'iTPN method only supports `pixel` and `clip`, ' \
            f'but got `{reconstruction_type}`.'
        self.reconstruction_type = reconstruction_type
        self.num_outs = num_outs

        self.build_transformer_pyramid(
            num_outs=num_outs,
            embed_dim=embed_dim,
            fpn_dim=fpn_dim,
            fpn_depth=fpn_depth,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            rpe=False,
            norm_cfg=norm_cfg,
        )

        # merge the output
        self.decoder_embed = nn.ModuleList()
        self.decoder_embed.append(
            nn.Sequential(
                nn.LayerNorm(fpn_dim),
                nn.Linear(fpn_dim, decoder_embed_dim, bias=True),
            ))

        if self.num_outs >= 2:
            self.decoder_embed.append(
                nn.Sequential(
                    nn.LayerNorm(fpn_dim),
                    nn.Linear(fpn_dim, decoder_embed_dim // 4, bias=True),
                ))
        if self.num_outs >= 3:
            self.decoder_embed.append(
                nn.Sequential(
                    nn.LayerNorm(fpn_dim),
                    nn.Linear(fpn_dim, decoder_embed_dim // 16, bias=True),
                ))
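        # The 8x and 4x pyramid levels keep 2x2 and 4x4 sub-tokens per patch,
        # so their projections use decoder_embed_dim // 4 and // 16 channels;
        # flattening the sub-tokens in `forward` then yields decoder_embed_dim
        # features per patch at every level.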

        if reconstruction_type == 'pixel':
            self.mask_token = nn.Parameter(
                torch.zeros(1, 1, decoder_embed_dim))

            # Create a new position embedding (different from the one in the
            # encoder) that is not learnable.
            self.decoder_pos_embed = nn.Parameter(
                torch.zeros(1, self.num_patches, decoder_embed_dim),
                requires_grad=False)

            self.decoder_blocks = nn.ModuleList([
                TransformerEncoderLayer(
                    decoder_embed_dim,
                    decoder_num_heads,
                    int(mlp_ratio * decoder_embed_dim),
                    qkv_bias=True,
                    norm_cfg=norm_cfg) for _ in range(decoder_depth)
            ])

            self.decoder_norm_name, decoder_norm = build_norm_layer(
                norm_cfg, decoder_embed_dim, postfix=1)
            self.add_module(self.decoder_norm_name, decoder_norm)

            # Used to map features to pixels
            if predict_feature_dim is None:
                predict_feature_dim = patch_size**2 * in_chans
            self.decoder_pred = nn.Linear(
                decoder_embed_dim, predict_feature_dim, bias=True)
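            # With the defaults (patch_size=16, in_chans=3), each patch is
            # reconstructed as 16 * 16 * 3 = 768 pixel values.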
        else:
            _, norm = build_norm_layer(norm_cfg, embed_dim)
            self.add_module('norm', norm)

    def build_transformer_pyramid(self,
                                  num_outs=3,
                                  embed_dim=512,
                                  fpn_dim=256,
                                  fpn_depth=2,
                                  mlp_ratio=4.0,
                                  qkv_bias=True,
                                  qk_scale=None,
                                  drop_rate=0.0,
                                  attn_drop_rate=0.0,
                                  rpe=False,
                                  norm_cfg=None):
        Hp = None
        mlvl_dims = {'4': embed_dim // 4, '8': embed_dim // 2, '16': embed_dim}
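        # Top-down pyramid: the final-stage (16x) tokens are up-sampled with
        # PatchSplit and fused with laterally projected 8x (and 4x) stage
        # tokens, similar in spirit to an FPN top-down path.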
        if num_outs > 1:
            if embed_dim != fpn_dim:
                self.align_dim_16tofpn = nn.Linear(embed_dim, fpn_dim)
            else:
                self.align_dim_16tofpn = None
            self.fpn_modules = nn.ModuleList()
            self.fpn_modules.append(
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg))
            self.fpn_modules.append(
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=False,
                    norm_cfg=norm_cfg,
                ))

            self.align_dim_16to8 = nn.Linear(
                mlvl_dims['8'], fpn_dim, bias=False)
            self.split_16to8 = PatchSplit(mlvl_dims['16'], fpn_dim, norm_cfg)
            self.block_16to8 = nn.Sequential(*[
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg,
                ) for _ in range(fpn_depth)
            ])

        if num_outs > 2:
            self.align_dim_8to4 = nn.Linear(
                mlvl_dims['4'], fpn_dim, bias=False)
            self.split_8to4 = PatchSplit(fpn_dim, fpn_dim, norm_cfg)
            self.block_8to4 = nn.Sequential(*[
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg,
                ) for _ in range(fpn_depth)
            ])
            self.fpn_modules.append(
                BlockWithRPE(
                    Hp,
                    fpn_dim,
                    0,
                    mlp_ratio,
                    qkv_bias,
                    qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    rpe=rpe,
                    norm_cfg=norm_cfg))

    def init_weights(self) -> None:
        """Initialize position embedding and mask token of MAE decoder."""
        super().init_weights()

        if self.reconstruction_type == 'pixel':
            decoder_pos_embed = build_2d_sincos_position_embedding(
                int(self.num_patches**.5),
                self.decoder_pos_embed.shape[-1],
                cls_token=False)
            self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())

            torch.nn.init.normal_(self.mask_token, std=.02)
        else:
            self.rescale_init_weight()

    def rescale_init_weight(self) -> None:
        """Rescale the initialized weights."""

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.fpn_modules):
            if isinstance(layer, BlockWithRPE):
                if layer.attn is not None:
                    rescale(layer.attn.proj.weight.data, layer_id + 1)
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    @property
    def decoder_norm(self):
        """The normalization layer of decoder."""
        return getattr(self, self.decoder_norm_name)

    def forward(self,
                x: torch.Tensor,
                ids_restore: torch.Tensor = None) -> torch.Tensor:
        """The forward function.

        The forward pass combines the visible patches' feature vectors with
        the mask tokens to produce the feature vectors used for
        reconstruction.

        Args:
            x (torch.Tensor): hidden features, which is of shape
                B x (L * (1 - mask_ratio)) x C.
            ids_restore (torch.Tensor): ids to restore original image.

        Returns:
            torch.Tensor: The reconstructed feature vectors, which is of
            shape B x (num_patches) x C.
        """
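        # `x` is indexed as a sequence of multi-scale encoder outputs below:
        # the first two entries hold the earlier-stage (4x and 8x) tokens and
        # the last entry holds the final-stage (16x) tokens.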
        features = x[:2]
        x = x[-1]
        B, L, _ = x.shape
        x = x[..., None, None, :]
        # Cast to int so the reshape in the `num_outs > 3` branch receives
        # integer sizes.
        Hp = Wp = int(math.sqrt(L))

        outs = [x] if self.align_dim_16tofpn is None else [
            self.align_dim_16tofpn(x)
        ]
        if self.num_outs >= 2:
            x = self.block_16to8(
                self.split_16to8(x) + self.align_dim_16to8(features[1]))
            outs.append(x)
        if self.num_outs >= 3:
            x = self.block_8to4(
                self.split_8to4(x) + self.align_dim_8to4(features[0]))
            outs.append(x)
        if self.num_outs > 3:
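            # Re-arrange the tokens of shape (B, L, h, w, C) into dense 2-D
            # feature maps of shape (B, C, Hp * h, Wp * w); extra coarser
            # levels are then created by average pooling.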
            outs = [
                out.reshape(B, Hp, Wp, *out.shape[-3:]).permute(
                    0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * out.shape[-3],
                                              Wp * out.shape[-2]).contiguous()
                for out in outs
            ]
            if self.num_outs >= 4:
                outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))
            if self.num_outs >= 5:
                outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))

        for i, out in enumerate(outs):
            out = self.fpn_modules[i](out)
            outs[i] = out

        if self.reconstruction_type == 'pixel':
            feats = []
            for feat, layer in zip(outs, self.decoder_embed):
                x = layer(feat).reshape(B, L, -1)
                # append mask tokens to sequence
                mask_tokens = self.mask_token.repeat(
                    x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
                x = torch.cat([x, mask_tokens], dim=1)
                x = torch.gather(
                    x,
                    dim=1,
                    index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
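                # Un-shuffle the tokens back to the original patch order
                # (MAE-style) so the fixed sin-cos position embedding can be
                # added afterwards.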
                feats.append(x)
            x = feats.pop(0)
            # add pos embed
            x = x + self.decoder_pos_embed

            for feat in feats:
                x = x + feat
            # apply Transformer blocks
            for blk in self.decoder_blocks:
                x = blk(x)
            x = self.decoder_norm(x)
            x = self.decoder_pred(x)
            return x
        else:
            feats = []
            for feat, layer in zip(outs, self.decoder_embed):
                x = layer(feat).reshape(B, L, -1)
                feats.append(x)
            x = feats.pop(0)
            for feat in feats:
                x = x + feat

            x = self.norm(x)

            return x
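

# A minimal usage sketch of the `pixel` branch (illustrative values, assuming
# a HiViT-style encoder with embed_dim=512, 196 patches and 49 visible
# tokens; not a recommended configuration):
#
#   neck = iTPNPretrainDecoder(num_patches=196, embed_dim=512, num_outs=3)
#   feats = [
#       torch.randn(2, 49, 4, 4, 128),  # 4x stage tokens
#       torch.randn(2, 49, 2, 2, 256),  # 8x stage tokens
#       torch.randn(2, 49, 512),        # final-stage (16x) tokens
#   ]
#   ids_restore = torch.randn(2, 196).argsort(dim=1)
#   pred = neck(feats, ids_restore)  # -> (2, 196, 16 * 16 * 3)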