# EasyCV/easycv/models/backbones/conv_mae_vit.py

# Copyright (c) Alibaba, Inc. and its affiliates.
# Reference: https://github.com/Alpha-VL/FastConvMAE
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from easycv.models.registry import BACKBONES
from easycv.models.utils import DropPath
from easycv.models.utils.pos_embed import get_2d_sincos_pos_embed
from .vision_transformer import Block
class PatchEmbed(nn.Module):
""" Image to Patch Embedding.
Args:
img_size (int | tuple): The size of input image
patch_size (int | tiple): The size of one patch
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dim=768):
super().__init__()
if not isinstance(img_size, (list, tuple)):
img_size = (img_size, img_size)
if not isinstance(patch_size, (list, tuple)):
patch_size = (patch_size, patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = nn.LayerNorm(embed_dim)
self.act = nn.GELU()
def forward(self, x):
_, _, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[
1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return self.act(x)
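# Usage sketch (illustrative, not part of the original module; shapes follow from the
# code above): with the stage-1 defaults used below (img_size=224, patch_size=4,
# embed_dim=256), PatchEmbed maps an image batch to a 56x56 feature map.
#
#   patch_embed = PatchEmbed(img_size=224, patch_size=4, in_channels=3, embed_dim=256)
#   feats = patch_embed(torch.randn(2, 3, 224, 224))  # -> (2, 256, 56, 56)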
class ConvMlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
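# Note (added for clarity): ConvMlp is the usual transformer MLP written with 1x1
# convolutions, so it operates directly on (B, C, H, W) feature maps and preserves
# the spatial size. A minimal sketch with assumed values:
#
#   mlp = ConvMlp(in_features=256, hidden_features=1024)
#   out = mlp(torch.randn(2, 256, 56, 56))  # -> (2, 256, 56, 56)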
class ConvBlock(nn.Module):
def __init__(self,
dim,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer=nn.GELU):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.conv1 = nn.Conv2d(dim, dim, 1)
self.conv2 = nn.Conv2d(dim, dim, 1)
self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ConvMlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x, mask=None):
if mask is not None:
residual = x
x = self.conv1(
self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2))
x1 = self.attn(mask[0] * x)
x2 = self.attn(mask[1] * x)
x3 = self.attn(mask[2] * x)
x4 = self.attn(mask[3] * x)
x = mask[0] * x1 + mask[1] * x2 + mask[2] * x3 + mask[3] * x4
x = residual + self.drop_path(self.conv2(x))
else:
x = x + self.drop_path(
self.conv2(
self.attn(
self.conv1(
self.norm1(x.permute(0, 2, 3, 1)).permute(
0, 3, 1, 2)))))
x = x + self.drop_path(
self.mlp(self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)))
return x
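# Sketch of the masked ConvBlock path (an illustration of the code above, with shapes
# assumed for the default stage-1 resolution): `mask` is a list of four complementary
# binary keep-masks of shape (B, 1, H, W). Each masked copy goes through the 5x5
# depthwise conv separately and is re-masked, so zeroed-out positions contribute
# nothing to the kept positions of any other mask set.
#
#   block = ConvBlock(dim=256)
#   x = torch.randn(2, 256, 56, 56)
#   keep = torch.zeros(2, 1, 56, 56)
#   keep[:, :, :, :14] = 1                      # toy partition into 4 vertical strips
#   masks = [torch.roll(keep, 14 * i, dims=3) for i in range(4)]
#   out = block(x, masks)                       # -> (2, 256, 56, 56)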
@BACKBONES.register_module
class FastConvMAEViT(nn.Module):
""" Fast ConvMAE framework is a superiorly fast masked modeling scheme via
complementary masking and mixture of reconstrunctors based on the ConvMAE(https://arxiv.org/abs/2205.03892).
Args:
img_size (list | tuple): Input image size for three stages.
patch_size (list | tuple): The patch size for three stages.
in_channels (int): The num of input channels. Default: 3
embed_dim (list | tuple): The dimensions of embedding for three stages.
depth (list | tuple): depth for three stages.
num_heads (int): Parallel attention heads
mlp_ratio (list | tuple): Mlp expansion ratio.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
norm_layer (nn.Module): normalization layer
init_pos_embed_by_sincos: initialize pos_embed by sincos strategy
with_fuse(bool): Whether to use fuse layers.
global_pool: global pool
"""
def __init__(
self,
img_size=[224, 56, 28],
patch_size=[4, 2, 2],
in_channels=3,
embed_dim=[256, 384, 768],
depth=[2, 2, 11],
num_heads=12,
mlp_ratio=[4, 4, 4],
drop_rate=0.,
drop_path_rate=0.1,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
init_pos_embed_by_sincos=True,
with_fuse=True,
global_pool=False,
):
super().__init__()
self.init_pos_embed_by_sincos = init_pos_embed_by_sincos
self.with_fuse = with_fuse
self.global_pool = global_pool
assert len(img_size) == len(patch_size) == len(embed_dim) == len(
mlp_ratio)
        # Overall downsampling factor of the three stages w.r.t. the input image.
        self.patch_size = patch_size[0] * patch_size[1] * patch_size[2]
self.patch_embed1 = PatchEmbed(
img_size=img_size[0],
patch_size=patch_size[0],
in_channels=in_channels,
embed_dim=embed_dim[0])
self.patch_embed2 = PatchEmbed(
img_size=img_size[1],
patch_size=patch_size[1],
in_channels=embed_dim[0],
embed_dim=embed_dim[1])
self.patch_embed3 = PatchEmbed(
img_size=img_size[2],
patch_size=patch_size[2],
in_channels=embed_dim[1],
embed_dim=embed_dim[2])
self.patch_embed4 = nn.Linear(embed_dim[2], embed_dim[2])
if with_fuse:
self._make_fuse_layers(embed_dim)
self.num_patches = self.patch_embed3.num_patches
if init_pos_embed_by_sincos:
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches, embed_dim[2]),
requires_grad=False)
else:
self.pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches, embed_dim[2]), )
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, sum(depth))
self.blocks1 = nn.ModuleList([
ConvBlock(
dim=embed_dim[0],
mlp_ratio=mlp_ratio[0],
drop=drop_rate,
drop_path=dpr[i],
) for i in range(depth[0])
])
self.blocks2 = nn.ModuleList([
ConvBlock(
dim=embed_dim[1],
mlp_ratio=mlp_ratio[1],
drop=drop_rate,
drop_path=dpr[depth[0] + i],
) for i in range(depth[1])
])
self.blocks3 = nn.ModuleList([
Block(
dim=embed_dim[2],
num_heads=num_heads,
mlp_ratio=mlp_ratio[2],
qkv_bias=True,
qk_scale=None,
drop=drop_rate,
drop_path=dpr[depth[0] + depth[1] + i],
norm_layer=norm_layer) for i in range(depth[2])
])
if self.global_pool:
self.fc_norm = norm_layer(embed_dim[-1])
self.norm = None
else:
self.norm = norm_layer(embed_dim[-1])
self.fc_norm = None
def init_weights(self):
if self.init_pos_embed_by_sincos:
# initialize (and freeze) pos_embed by sin-cos embedding
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.num_patches**.5),
cls_token=False)
self.pos_embed.data.copy_(
torch.from_numpy(pos_embed).float().unsqueeze(0))
else:
trunc_normal_(self.pos_embed, std=.02)
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed3.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _make_fuse_layers(self, embed_dim):
self.stage1_output_decode = nn.Conv2d(
embed_dim[0], embed_dim[2], 4, stride=4)
self.stage2_output_decode = nn.Conv2d(
embed_dim[1], embed_dim[2], 2, stride=2)
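        # Note on the fuse layers (added comment): stage-1 features (56x56, embed_dim[0])
        # and stage-2 features (28x28, embed_dim[1]) are projected to embed_dim[2] and
        # downsampled by 4x and 2x respectively, so both align with the final 14x14
        # token grid of the transformer stage (sizes assume the default 224 config).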
def random_masking(self, x, mask_ratio=None):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N = x.shape[0]
L = self.num_patches
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(
noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep1 = ids_shuffle[:, :len_keep]
ids_keep2 = ids_shuffle[:, len_keep:2 * len_keep]
ids_keep3 = ids_shuffle[:, 2 * len_keep:3 * len_keep]
ids_keep4 = ids_shuffle[:, 3 * len_keep:]
# generate the binary mask: 0 is keep, 1 is remove
mask1 = torch.ones([N, L], device=x.device)
mask1[:, :len_keep] = 0
# unshuffle to get the binary mask
mask1 = torch.gather(mask1, dim=1, index=ids_restore)
mask2 = torch.ones([N, L], device=x.device)
mask2[:, len_keep:2 * len_keep] = 0
# unshuffle to get the binary mask
mask2 = torch.gather(mask2, dim=1, index=ids_restore)
mask3 = torch.ones([N, L], device=x.device)
mask3[:, 2 * len_keep:3 * len_keep] = 0
# unshuffle to get the binary mask
mask3 = torch.gather(mask3, dim=1, index=ids_restore)
mask4 = torch.ones([N, L], device=x.device)
mask4[:, 3 * len_keep:4 * len_keep] = 0
# unshuffle to get the binary mask
mask4 = torch.gather(mask4, dim=1, index=ids_restore)
return [ids_keep1, ids_keep2, ids_keep3,
ids_keep4], [mask1, mask2, mask3, mask4], ids_restore
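    # Worked example (assumed numbers for the default 224 config): L = 14 * 14 = 196
    # patches and mask_ratio = 0.75 give len_keep = 49. ids_shuffle is split into four
    # disjoint keep sets of 49 indices each, so the four masks are complementary and
    # together cover every patch exactly once:
    #
    #   ids_keep, masks, ids_restore = model.random_masking(x, mask_ratio=0.75)
    #   # each ids_keep[i]: (N, 49); each masks[i]: (N, 196) with 147 ones (removed)
    #   # ids_restore: (N, 196)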
def _fuse_forward(self, s1, s2, ids_keep=None, mask_ratio=None):
stage1_embed = self.stage1_output_decode(s1).flatten(2).permute(
0, 2, 1)
stage2_embed = self.stage2_output_decode(s2).flatten(2).permute(
0, 2, 1)
if mask_ratio is not None:
stage1_embed_1 = torch.gather(
stage1_embed,
dim=1,
index=ids_keep[0].unsqueeze(-1).repeat(1, 1,
stage1_embed.shape[-1]))
stage2_embed_1 = torch.gather(
stage2_embed,
dim=1,
index=ids_keep[0].unsqueeze(-1).repeat(1, 1,
stage2_embed.shape[-1]))
stage1_embed_2 = torch.gather(
stage1_embed,
dim=1,
index=ids_keep[1].unsqueeze(-1).repeat(1, 1,
stage1_embed.shape[-1]))
stage2_embed_2 = torch.gather(
stage2_embed,
dim=1,
index=ids_keep[1].unsqueeze(-1).repeat(1, 1,
stage2_embed.shape[-1]))
stage1_embed_3 = torch.gather(
stage1_embed,
dim=1,
index=ids_keep[2].unsqueeze(-1).repeat(1, 1,
stage1_embed.shape[-1]))
stage2_embed_3 = torch.gather(
stage2_embed,
dim=1,
index=ids_keep[2].unsqueeze(-1).repeat(1, 1,
stage2_embed.shape[-1]))
stage1_embed_4 = torch.gather(
stage1_embed,
dim=1,
index=ids_keep[3].unsqueeze(-1).repeat(1, 1,
stage1_embed.shape[-1]))
stage2_embed_4 = torch.gather(
stage2_embed,
dim=1,
index=ids_keep[3].unsqueeze(-1).repeat(1, 1,
stage2_embed.shape[-1]))
stage1_embed = torch.cat([
stage1_embed_1, stage1_embed_2, stage1_embed_3, stage1_embed_4
])
stage2_embed = torch.cat([
stage2_embed_1, stage2_embed_2, stage2_embed_3, stage2_embed_4
])
return stage1_embed, stage2_embed
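    # Shape sketch for _fuse_forward during pretraining (assumed default config):
    # stage1_embed / stage2_embed start as (N, 196, embed_dim[2]); after gathering the
    # four complementary keep sets and concatenating along the batch dimension they
    # become (4 * N, 49, embed_dim[2]), matching the 4-way expansion of the token
    # sequence in `forward`.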
def forward(self, x, mask_ratio=None):
if mask_ratio is not None:
assert self.with_fuse
# embed patches
if mask_ratio is not None:
ids_keep, masks, ids_restore = self.random_masking(x, mask_ratio)
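            # The reshape/repeat below upsamples each 14x14 token-level mask to the
            # stage-1 (56x56) and stage-2 (28x28) resolutions by expanding every token
            # decision to a 4x4 / 2x2 block; the hard-coded sizes assume the default
            # img_size=224 with patch sizes (4, 2, 2).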
mask_for_patch1 = [
1 - mask.reshape(-1, 14, 14).unsqueeze(-1).repeat(
1, 1, 1, 16).reshape(-1, 14, 14, 4, 4).permute(
0, 1, 3, 2, 4).reshape(x.shape[0], 56, 56).unsqueeze(1)
for mask in masks
]
mask_for_patch2 = [
1 - mask.reshape(-1, 14, 14).unsqueeze(-1).repeat(
1, 1, 1, 4).reshape(-1, 14, 14, 2, 2).permute(
0, 1, 3, 2, 4).reshape(x.shape[0], 28, 28).unsqueeze(1)
for mask in masks
]
        else:
            # No masking (e.g. fine-tuning); keep the fuse path below valid.
            ids_keep, masks, ids_restore = None, None, None
            mask_for_patch1 = None
            mask_for_patch2 = None
s1 = self.patch_embed1(x)
s1 = self.pos_drop(s1)
for blk in self.blocks1:
s1 = blk(s1, mask_for_patch1)
s2 = self.patch_embed2(s1)
for blk in self.blocks2:
s2 = blk(s2, mask_for_patch2)
if self.with_fuse:
stage1_embed, stage2_embed = self._fuse_forward(
s1, s2, ids_keep, mask_ratio)
x = self.patch_embed3(s2)
x = x.flatten(2).permute(0, 2, 1)
x = self.patch_embed4(x)
# add pos embed w/o cls token
x = x + self.pos_embed
if mask_ratio is not None:
x1 = torch.gather(
x,
dim=1,
index=ids_keep[0].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
x2 = torch.gather(
x,
dim=1,
index=ids_keep[1].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
x3 = torch.gather(
x,
dim=1,
index=ids_keep[2].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
x4 = torch.gather(
x,
dim=1,
index=ids_keep[3].unsqueeze(-1).repeat(1, 1, x.shape[-1]))
x = torch.cat([x1, x2, x3, x4])
# apply Transformer blocks
for blk in self.blocks3:
x = blk(x)
if self.with_fuse:
x = x + stage1_embed + stage2_embed
if self.global_pool:
x = x[:, 1:, :].mean(dim=1) # global pool without cls token
x = self.fc_norm(x)
else:
x = self.norm(x)
if mask_ratio is not None:
mask = torch.cat([masks[0], masks[1], masks[2], masks[3]])
return x, mask, ids_restore
return x, None, None
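# Usage sketch (illustrative only; shapes assume the default 224 config and the
# default with_fuse=True, global_pool=False):
#
#   model = FastConvMAEViT()
#   model.init_weights()
#   imgs = torch.randn(2, 3, 224, 224)
#
#   # Pretraining: complementary masking expands the batch 4x.
#   tokens, mask, ids_restore = model(imgs, mask_ratio=0.75)
#   # tokens: (8, 49, 768), mask: (8, 196), ids_restore: (2, 196)
#
#   # Fine-tuning / feature extraction: no masking.
#   tokens, _, _ = model(imgs)                  # -> (2, 196, 768)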