# Copyright (c) Alibaba, Inc. and its affiliates.
# Reference: https://github.com/Alpha-VL/FastConvMAE
from functools import partial

import numpy as np
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

from easycv.models.registry import BACKBONES
from easycv.models.utils import DropPath
from easycv.models.utils.pos_embed import get_2d_sincos_pos_embed

from .vision_transformer import Block


class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Args:
        img_size (int | tuple): The size of the input image.
        patch_size (int | tuple): The size of one patch.
        in_channels (int): The number of input channels. Default: 3
        embed_dim (int): The dimension of the embedding. Default: 768
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dim=768):
        super().__init__()
        if not isinstance(img_size, (list, tuple)):
            img_size = (img_size, img_size)
        if not isinstance(patch_size, (list, tuple)):
            patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)
        self.act = nn.GELU()

    def forward(self, x):
        _, _, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        # LayerNorm expects channels last, so permute around the norm.
        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return self.act(x)


class ConvMlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class ConvBlock(nn.Module):

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ConvMlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

    def forward(self, x, mask=None):
        if mask is not None:
            residual = x
            x = self.conv1(
                self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2))
            # Run the depthwise conv on each complementary view separately,
            # keeping only the positions visible in that view both before
            # and after the conv so masked positions never leak in.
            x1 = self.attn(mask[0] * x)
            x2 = self.attn(mask[1] * x)
            x3 = self.attn(mask[2] * x)
            x4 = self.attn(mask[3] * x)
            x = mask[0] * x1 + mask[1] * x2 + mask[2] * x3 + mask[3] * x4
            x = residual + self.drop_path(self.conv2(x))
        else:
            x = x + self.drop_path(
                self.conv2(
                    self.attn(
                        self.conv1(
                            self.norm1(x.permute(0, 2, 3, 1)).permute(
                                0, 3, 1, 2)))))
        x = x + self.drop_path(
            self.mlp(self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)))
        return x
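
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original FastConvMAE code): shapes
# through the convolutional patch-embedding stem. The helper name
# `_example_patch_embed` is hypothetical and exists only for illustration.
def _example_patch_embed():
    """A minimal sketch of the stage-1 stem: a 224x224 image embedded with
    patch_size=4 yields a 56x56 grid of 256-dim tokens (224 // 4 = 56)."""
    emb = PatchEmbed(img_size=224, patch_size=4, in_channels=3, embed_dim=256)
    x = torch.randn(2, 3, 224, 224)
    out = emb(x)  # conv stem + channel-last LayerNorm + GELU
    assert out.shape == (2, 256, 56, 56)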
@BACKBONES.register_module
class FastConvMAEViT(nn.Module):
    """Fast ConvMAE is a significantly faster masked-modeling scheme that
    combines complementary masking with a mixture of reconstructors, built on
    top of ConvMAE (https://arxiv.org/abs/2205.03892).

    Args:
        img_size (list | tuple): Input image size for the three stages.
        patch_size (list | tuple): The patch size for the three stages.
        in_channels (int): The number of input channels. Default: 3
        embed_dim (list | tuple): The embedding dimensions of the three stages.
        depth (list | tuple): The depth of the three stages.
        num_heads (int): The number of parallel attention heads.
        mlp_ratio (list | tuple): Mlp expansion ratio for the three stages.
        drop_rate (float): Probability of an element to be zeroed after the
            feed forward layer. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        norm_layer (nn.Module): Normalization layer.
        init_pos_embed_by_sincos (bool): Whether to initialize pos_embed with
            a fixed sin-cos embedding.
        with_fuse (bool): Whether to use fuse layers.
        global_pool (bool): Whether to use global average pooling instead of
            the final norm.
    """

    def __init__(
        self,
        img_size=[224, 56, 28],
        patch_size=[4, 2, 2],
        in_channels=3,
        embed_dim=[256, 384, 768],
        depth=[2, 2, 11],
        num_heads=12,
        mlp_ratio=[4, 4, 4],
        drop_rate=0.,
        drop_path_rate=0.1,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_pos_embed_by_sincos=True,
        with_fuse=True,
        global_pool=False,
    ):
        super().__init__()
        self.init_pos_embed_by_sincos = init_pos_embed_by_sincos
        self.with_fuse = with_fuse
        self.global_pool = global_pool
        assert len(img_size) == len(patch_size) == len(embed_dim) == len(
            mlp_ratio)

        # Overall stride of the three stems, e.g. 4 * 2 * 2 = 16.
        self.patch_size = patch_size[0] * patch_size[1] * patch_size[2]
        self.patch_embed1 = PatchEmbed(
            img_size=img_size[0],
            patch_size=patch_size[0],
            in_channels=in_channels,
            embed_dim=embed_dim[0])
        self.patch_embed2 = PatchEmbed(
            img_size=img_size[1],
            patch_size=patch_size[1],
            in_channels=embed_dim[0],
            embed_dim=embed_dim[1])
        self.patch_embed3 = PatchEmbed(
            img_size=img_size[2],
            patch_size=patch_size[2],
            in_channels=embed_dim[1],
            embed_dim=embed_dim[2])
        self.patch_embed4 = nn.Linear(embed_dim[2], embed_dim[2])
        if with_fuse:
            self._make_fuse_layers(embed_dim)

        self.num_patches = self.patch_embed3.num_patches
        if init_pos_embed_by_sincos:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, self.num_patches, embed_dim[2]),
                requires_grad=False)
        else:
            self.pos_embed = nn.Parameter(
                torch.zeros(1, self.num_patches, embed_dim[2]))

        self.pos_drop = nn.Dropout(p=drop_rate)
        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, sum(depth))

        self.blocks1 = nn.ModuleList([
            ConvBlock(
                dim=embed_dim[0],
                mlp_ratio=mlp_ratio[0],
                drop=drop_rate,
                drop_path=dpr[i]) for i in range(depth[0])
        ])
        self.blocks2 = nn.ModuleList([
            ConvBlock(
                dim=embed_dim[1],
                mlp_ratio=mlp_ratio[1],
                drop=drop_rate,
                drop_path=dpr[depth[0] + i]) for i in range(depth[1])
        ])
        self.blocks3 = nn.ModuleList([
            Block(
                dim=embed_dim[2],
                num_heads=num_heads,
                mlp_ratio=mlp_ratio[2],
                qkv_bias=True,
                qk_scale=None,
                drop=drop_rate,
                drop_path=dpr[depth[0] + depth[1] + i],
                norm_layer=norm_layer) for i in range(depth[2])
        ])

        if self.global_pool:
            self.fc_norm = norm_layer(embed_dim[-1])
            self.norm = None
        else:
            self.norm = norm_layer(embed_dim[-1])
            self.fc_norm = None

    def init_weights(self):
        if self.init_pos_embed_by_sincos:
            # initialize (and freeze) pos_embed by sin-cos embedding
            pos_embed = get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1],
                int(self.num_patches**.5),
                cls_token=False)
            self.pos_embed.data.copy_(
                torch.from_numpy(pos_embed).float().unsqueeze(0))
        else:
            trunc_normal_(self.pos_embed, std=.02)

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed3.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _make_fuse_layers(self, embed_dim):
        # Project stage-1 (stride 4) and stage-2 (stride 8) features down to
        # the final token grid (stride 16) so they can be fused with stage 3.
        self.stage1_output_decode = nn.Conv2d(
            embed_dim[0], embed_dim[2], 4, stride=4)
        self.stage2_output_decode = nn.Conv2d(
            embed_dim[1], embed_dim[2], 2, stride=2)
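    # Complementary masking in a nutshell: `random_masking` draws one random
    # permutation of the L final-stage tokens per sample and slices it into
    # four consecutive chunks. With the intended mask_ratio of 0.75 and
    # L = 196 (a 14x14 grid), len_keep = int(196 * 0.25) = 49, so the four
    # ids_keep sets are disjoint and together cover all 196 patches: every
    # patch is visible in exactly one of the four complementary views.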
    def random_masking(self, x, mask_ratio=None):
        """Perform per-sample random masking by per-sample shuffling.

        Per-sample shuffling is done by argsorting random noise. `x` is only
        used for its batch size and device; the masks are built over the
        `self.num_patches` tokens of the final stage.
        """
        N = x.shape[0]
        L = self.num_patches
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep four disjoint subsets of the shuffled order
        ids_keep1 = ids_shuffle[:, :len_keep]
        ids_keep2 = ids_shuffle[:, len_keep:2 * len_keep]
        ids_keep3 = ids_shuffle[:, 2 * len_keep:3 * len_keep]
        ids_keep4 = ids_shuffle[:, 3 * len_keep:]

        # generate the binary masks: 0 is keep, 1 is remove,
        # then unshuffle each one to get the mask in token order
        mask1 = torch.ones([N, L], device=x.device)
        mask1[:, :len_keep] = 0
        mask1 = torch.gather(mask1, dim=1, index=ids_restore)

        mask2 = torch.ones([N, L], device=x.device)
        mask2[:, len_keep:2 * len_keep] = 0
        mask2 = torch.gather(mask2, dim=1, index=ids_restore)

        mask3 = torch.ones([N, L], device=x.device)
        mask3[:, 2 * len_keep:3 * len_keep] = 0
        mask3 = torch.gather(mask3, dim=1, index=ids_restore)

        mask4 = torch.ones([N, L], device=x.device)
        mask4[:, 3 * len_keep:4 * len_keep] = 0
        mask4 = torch.gather(mask4, dim=1, index=ids_restore)

        return [ids_keep1, ids_keep2, ids_keep3,
                ids_keep4], [mask1, mask2, mask3, mask4], ids_restore

    def _fuse_forward(self, s1, s2, ids_keep=None, mask_ratio=None):
        # Decode stage-1/stage-2 feature maps to token sequences on the
        # final 14x14 grid: (N, C, H, W) -> (N, H*W, embed_dim[2]).
        stage1_embed = self.stage1_output_decode(s1).flatten(2).permute(
            0, 2, 1)
        stage2_embed = self.stage2_output_decode(s2).flatten(2).permute(
            0, 2, 1)
        if mask_ratio is not None:
            # Gather the tokens kept by each complementary view, then stack
            # the four views along the batch dimension.
            stage1_embed_1 = torch.gather(
                stage1_embed,
                dim=1,
                index=ids_keep[0].unsqueeze(-1).repeat(
                    1, 1, stage1_embed.shape[-1]))
            stage2_embed_1 = torch.gather(
                stage2_embed,
                dim=1,
                index=ids_keep[0].unsqueeze(-1).repeat(
                    1, 1, stage2_embed.shape[-1]))
            stage1_embed_2 = torch.gather(
                stage1_embed,
                dim=1,
                index=ids_keep[1].unsqueeze(-1).repeat(
                    1, 1, stage1_embed.shape[-1]))
            stage2_embed_2 = torch.gather(
                stage2_embed,
                dim=1,
                index=ids_keep[1].unsqueeze(-1).repeat(
                    1, 1, stage2_embed.shape[-1]))
            stage1_embed_3 = torch.gather(
                stage1_embed,
                dim=1,
                index=ids_keep[2].unsqueeze(-1).repeat(
                    1, 1, stage1_embed.shape[-1]))
            stage2_embed_3 = torch.gather(
                stage2_embed,
                dim=1,
                index=ids_keep[2].unsqueeze(-1).repeat(
                    1, 1, stage2_embed.shape[-1]))
            stage1_embed_4 = torch.gather(
                stage1_embed,
                dim=1,
                index=ids_keep[3].unsqueeze(-1).repeat(
                    1, 1, stage1_embed.shape[-1]))
            stage2_embed_4 = torch.gather(
                stage2_embed,
                dim=1,
                index=ids_keep[3].unsqueeze(-1).repeat(
                    1, 1, stage2_embed.shape[-1]))
            stage1_embed = torch.cat([
                stage1_embed_1, stage1_embed_2, stage1_embed_3, stage1_embed_4
            ])
            stage2_embed = torch.cat([
                stage2_embed_1, stage2_embed_2, stage2_embed_3, stage2_embed_4
            ])

        return stage1_embed, stage2_embed
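    # In `forward`, each (N, 196) token mask is broadcast back to the higher
    # resolution stages by nearest-neighbour upsampling via reshape/repeat:
    # the 14x14 grid becomes 56x56 for stage 1 (each token covers a 4x4 block
    # of stage-1 positions) and 28x28 for stage 2 (2x2 blocks). The masks are
    # also inverted (1 - mask) so that 1 marks *visible* positions, matching
    # the multiplicative masking in `ConvBlock.forward`.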
    def forward(self, x, mask_ratio=None):
        if mask_ratio is not None:
            assert self.with_fuse

        # embed patches
        if mask_ratio is not None:
            ids_keep, masks, ids_restore = self.random_masking(x, mask_ratio)
            mask_for_patch1 = [
                1 - mask.reshape(-1, 14, 14).unsqueeze(-1).repeat(
                    1, 1, 1, 16).reshape(-1, 14, 14, 4, 4).permute(
                        0, 1, 3, 2, 4).reshape(x.shape[0], 56,
                                               56).unsqueeze(1)
                for mask in masks
            ]
            mask_for_patch2 = [
                1 - mask.reshape(-1, 14, 14).unsqueeze(-1).repeat(
                    1, 1, 1, 4).reshape(-1, 14, 14, 2, 2).permute(
                        0, 1, 3, 2, 4).reshape(x.shape[0], 28,
                                               28).unsqueeze(1)
                for mask in masks
            ]
        else:
            ids_keep = None  # avoid a NameError in the fuse path below
            mask_for_patch1 = None
            mask_for_patch2 = None

        s1 = self.patch_embed1(x)
        s1 = self.pos_drop(s1)
        for blk in self.blocks1:
            s1 = blk(s1, mask_for_patch1)

        s2 = self.patch_embed2(s1)
        for blk in self.blocks2:
            s2 = blk(s2, mask_for_patch2)

        if self.with_fuse:
            stage1_embed, stage2_embed = self._fuse_forward(
                s1, s2, ids_keep, mask_ratio)

        x = self.patch_embed3(s2)
        x = x.flatten(2).permute(0, 2, 1)
        x = self.patch_embed4(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed

        if mask_ratio is not None:
            # keep only the visible tokens of each complementary view and
            # stack the four views along the batch dimension
            x1 = torch.gather(
                x,
                dim=1,
                index=ids_keep[0].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
            x2 = torch.gather(
                x,
                dim=1,
                index=ids_keep[1].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
            x3 = torch.gather(
                x,
                dim=1,
                index=ids_keep[2].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
            x4 = torch.gather(
                x,
                dim=1,
                index=ids_keep[3].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
            x = torch.cat([x1, x2, x3, x4])

        # apply Transformer blocks
        for blk in self.blocks3:
            x = blk(x)

        if self.with_fuse:
            x = x + stage1_embed + stage2_embed

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            x = self.fc_norm(x)
        else:
            x = self.norm(x)

        if mask_ratio is not None:
            mask = torch.cat([masks[0], masks[1], masks[2], masks[3]])
            return x, mask, ids_restore

        return x, None, None
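
# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the masked
# pre-training forward pass with the default configuration. The helper name
# `_example_masked_forward` is hypothetical.
def _example_masked_forward():
    """A minimal sketch, assuming default hyper-parameters: 224x224 inputs
    give 14 * 14 = 196 final-stage tokens; at mask_ratio=0.75 each of the
    four complementary views keeps int(196 * 0.25) = 49 tokens, and the
    views are concatenated along the batch dimension."""
    model = FastConvMAEViT()
    model.init_weights()
    imgs = torch.randn(2, 3, 224, 224)
    tokens, mask, ids_restore = model(imgs, mask_ratio=0.75)
    assert tokens.shape == (8, 49, 768)  # 4 views * batch of 2, 49 kept each
    assert mask.shape == (8, 196)  # binary token masks: 0 = keep, 1 = remove
    assert ids_restore.shape == (2, 196)  # shared unshuffle indices
    # Without a mask ratio the model runs as a plain feature extractor:
    feats, _, _ = model(imgs)
    assert feats.shape == (2, 196, 768)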