# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from torch import nn

from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import PromptMultiheadAttention
from .mae_neck import MAEPretrainDecoder


class PromptTransformerEncoderLayer(TransformerEncoderLayer):
"""Prompt Transformer Encoder Layer for MILAN.
|
|
|
|
This module is specific for the prompt encoder in MILAN. It will not update
|
|
the visible tokens from the encoder.
|
|
|
|
Args:
|
|
embed_dims (int): The feature dimension.
|
|
num_heads (int): Parallel attention heads.
|
|
feedforward_channels (int): The hidden dimension for FFNs.
|
|
drop_rate (float): Probability of an element to be zeroed
|
|
after the feed forward layer. Defaults to 0.0.
|
|
attn_drop_rate (float): The drop out rate for attention layer.
|
|
Defaults to 0.0.
|
|
drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
|
|
num_fcs (int): The number of fully-connected layers for FFNs.
|
|
Defaults to 2.
|
|
qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
|
|
act_cfg (dict): The activation config for FFNs.
|
|
Defaults to ``dict(type='GELU')``.
|
|
norm_cfg (dict): Config dict for normalization layer.
|
|
Defaults to ``dict(type='LN')``.
|
|
batch_first (bool): Key, Query and Value are shape of
|
|
(batch, n, embed_dim)
|
|
or (n, batch, embed_dim). Defaults to False.
|
|
init_cfg (dict, optional): The Config for initialization.
|
|
Defaults to None.
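
    Example:
        A toy-sized usage sketch; the sizes below are illustrative only,
        as long as ``L_m + L_v`` equals the length of ``ids_restore``:

        >>> import torch
        >>> layer = PromptTransformerEncoderLayer(
        ...     embed_dims=16, num_heads=2, feedforward_channels=32)
        >>> mask_tokens = torch.rand(2, 2, 16)     # N x L_m x C
        >>> visible_tokens = torch.rand(2, 3, 16)  # cls token + L_v tokens
        >>> ids_restore = torch.tensor([[0, 1, 2, 3], [2, 3, 0, 1]])
        >>> layer(mask_tokens, visible_tokens, ids_restore).shape
        torch.Size([2, 2, 16])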
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 feedforward_channels: int,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 num_fcs: int = 2,
                 qkv_bias: bool = True,
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=feedforward_channels,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            num_fcs=num_fcs,
            qkv_bias=qkv_bias,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)
        self.attn = PromptMultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias)

    def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor,
                ids_restore: torch.Tensor) -> torch.Tensor:
        """Forward function for `PromptMultiheadAttention`.

        Args:
            x (torch.Tensor): Mask token features with shape N x L_m x C.
            visible_tokens (torch.Tensor): The visible token features from
                the encoder with shape N x L_v x C.
            ids_restore (torch.Tensor): The ids of all tokens in the original
                image with shape N x L.

        Returns:
            torch.Tensor: Output features with shape N x L_m x C.
        """
        x = x + self.attn(self.norm1(x), visible_tokens, ids_restore)
        x = self.ffn(self.norm2(x), identity=x)
        return x


@MODELS.register_module()
class MILANPretrainDecoder(MAEPretrainDecoder):
"""Prompt decoder for MILAN.
|
|
|
|
This decoder is used in MILAN pretraining, which will not update these
|
|
visible tokens from the encoder.
|
|
|
|
Args:
|
|
num_patches (int): The number of total patches. Defaults to 196.
|
|
patch_size (int): Image patch size. Defaults to 16.
|
|
in_chans (int): The channel of input image. Defaults to 3.
|
|
embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
|
|
decoder_embed_dim (int): Decoder's embedding dimension.
|
|
Defaults to 512.
|
|
decoder_depth (int): The depth of decoder. Defaults to 8.
|
|
decoder_num_heads (int): Number of attention heads of decoder.
|
|
Defaults to 16.
|
|
predict_feature_dim (int): The dimension of the feature to be
|
|
predicted. Defaults to 512.
|
|
mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
|
|
Defaults to 4.
|
|
norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
|
|
init_cfg (Union[List[dict], dict], optional): Initialization config
|
|
dict. Defaults to None.
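
    Example:
        A toy-sized shape sketch; the sizes below are illustrative only and
        much smaller than the real defaults:

        >>> import torch
        >>> decoder = MILANPretrainDecoder(
        ...     num_patches=4, embed_dim=8, decoder_embed_dim=8,
        ...     decoder_depth=1, decoder_num_heads=2, predict_feature_dim=6)
        >>> ids_shuffle = torch.stack([torch.randperm(4) for _ in range(2)])
        >>> ids_restore = torch.argsort(ids_shuffle, dim=1)
        >>> ids_keep, ids_dump = ids_shuffle[:, :2], ids_shuffle[:, 2:]
        >>> x = torch.rand(2, 3, 8)  # cls token + 2 visible tokens
        >>> decoder(x, ids_restore, ids_keep, ids_dump).shape
        torch.Size([2, 5, 6])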
    """

    def __init__(self,
                 num_patches: int = 196,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 predict_feature_dim: int = 512,
                 mlp_ratio: int = 4,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            num_patches=num_patches,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)

        # map the dim of features from decoder to the dim compatible with
        # that of CLIP
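        # (the default predict_feature_dim of 512 matches the 512-d image
        # embeddings of a CLIP ViT-B/16 image encoder)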
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, predict_feature_dim, bias=True)

        # use prompt transformer encoder layer, instead of the conventional
        # transformer encoder layer
        self.decoder_blocks = nn.ModuleList([
            PromptTransformerEncoderLayer(
                decoder_embed_dim,
                decoder_num_heads,
                int(mlp_ratio * decoder_embed_dim),
                qkv_bias=True,
                norm_cfg=norm_cfg) for _ in range(decoder_depth)
        ])

    def forward(self, x: torch.Tensor, ids_restore: torch.Tensor,
                ids_keep: torch.Tensor,
                ids_dump: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (torch.Tensor): The input features, which are of shape
                (N, L, C).
            ids_restore (torch.Tensor): The indices to restore these tokens
                to the original image.
            ids_keep (torch.Tensor): The indices of tokens to be kept.
            ids_dump (torch.Tensor): The indices of tokens to be masked.

        Returns:
            torch.Tensor: The reconstructed features, which are of shape
            (N, L, C).
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
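        # ids_restore covers all patches; the "+ 1" below accounts for the
        # cls token kept at the front of x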
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)

        # add pos embed
        x = x + self.decoder_pos_embed

        # split mask tokens and visible tokens
        visible_tokens = torch.cat([
            x[:, :1, :],
            torch.gather(
                x[:, 1:, :],
                dim=1,
                index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        ], dim=1)
        x = torch.gather(
            x[:, 1:, :],
            dim=1,
            index=ids_dump.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
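
        # only the mask tokens are refined by the decoder blocks; the
        # visible tokens from the encoder act as fixed prompts and are
        # not updated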
        for blk in self.decoder_blocks:
            x = blk(x, visible_tokens, ids_restore)

        # full sequence recovery
        x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1)
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1,
                                                   x.shape[-1]))  # unshuffle
        x = torch.cat([visible_tokens[:, :1, :], x_], dim=1)

        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        return x