EasyCV/easycv/models/backbones/mae_vit_transformer.py

135 lines
4.5 KiB
Python

"""
Mostly copy-paste from
https://github.com/facebookresearch/mae/blob/main/models_mae.py
"""
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import Block, PatchEmbed
from easycv.models.utils import get_2d_sincos_pos_embed
from ..registry import BACKBONES
@BACKBONES.register_module
class MaskedAutoencoderViT(nn.Module):
    """Masked Autoencoder with VisionTransformer backbone.

    MaskedAutoencoderViT is mostly the same as vit_tranformer_dynamic, but
    adds a ``random_masking`` step that drops a random subset of the patch
    tokens before the Transformer blocks are applied. A MaskedAutoencoderViT
    model can be loaded by vit_tranformer_dynamic.

    Args:
        img_size (int): input image size
        patch_size (int): patch size
        in_chans (int): input image channels
        embed_dim (int): feature dimensions
        depth (int): number of encoder layers
        num_heads (int): parallel attention heads
        mlp_ratio (float): mlp ratio
        norm_layer: callable building the normalization layer from the
            feature dimension, e.g. ``partial(nn.LayerNorm, eps=1e-6)``
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.patch_size = patch_size

        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans,
                                      embed_dim)
        self.num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed sin-cos embedding: one slot per patch plus one for the cls
        # token. It is excluded from gradient updates and is filled with its
        # real values by ``init_weights``.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim),
            requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(
                embed_dim,
                num_heads,
                mlp_ratio,
                qkv_bias=True,
                norm_layer=norm_layer) for _ in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

    def init_weights(self):
        """Initialize the patch projection, cls token, positional embedding
        and every ``nn.Linear`` / ``nn.LayerNorm`` sub-module in place."""
        # Xavier-init the patch projection on a 2D view of its weight
        # (out_channels x everything else); the view shares storage, so the
        # in-place init writes into the real weight.
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=.02)

        # The grid side is derived as sqrt(num_patches), i.e. a square patch
        # grid is assumed here.
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.patch_embed.num_patches**.5),
            cls_token=True)
        self.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))

        for m in self.modules():
            if isinstance(m, nn.Linear):
                # we use xavier_uniform following official JAX ViT:
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                # A LayerNorm built without affine parameters exposes
                # ``None`` for weight and/or bias; skip those instead of
                # crashing inside ``nn.init``.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """Perform per-sample random masking by per-sample shuffling.

        Per-sample shuffling is done by argsort of random noise.

        Args:
            x: [N, L, D], sequence
            mask_ratio (float): fraction of the L tokens to remove, in
                [0, 1]; ``int(L * (1 - mask_ratio))`` tokens are kept.

        Returns:
            tuple: ``(x_masked, mask, ids_restore)`` where ``x_masked`` is
            [N, len_keep, D] holding the kept tokens, ``mask`` is [N, L]
            in the original token order with 0 for keep and 1 for remove,
            and ``ids_restore`` is [N, L], the argsort of the shuffle
            indices (used to undo the shuffle).

        Raises:
            ValueError: if ``mask_ratio`` is outside [0, 1].
        """
        # A ratio above 1 would make ``len_keep`` negative and the slices
        # below would then silently keep tokens from the wrong end.
        if not 0 <= mask_ratio <= 1:
            raise ValueError(
                f'mask_ratio must be in [0, 1], got {mask_ratio}')

        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        # ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward(self, x, mask_ratio):
        """Encode the randomly kept patches of ``x``.

        Args:
            x: input batch passed to the patch embedding (presumably
                images of shape [N, in_chans, img_size, img_size] --
                confirm against the ``PatchEmbed`` in use).
            mask_ratio (float): fraction of patch tokens to remove, see
                ``random_masking``.

        Returns:
            tuple: ``(x, mask, ids_restore)`` where ``x`` is the normalized
            output of the Transformer blocks with the cls token at
            sequence index 0 followed by the kept patch tokens, and
            ``mask`` / ``ids_restore`` are as returned by
            ``random_masking``.
        """
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token (with its own positional slot, index 0)
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore