"""
|
|
Mostly copy-paste from
|
|
https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
|
"""
|
|
|
|
from functools import partial
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from timm.models.vision_transformer import Block, PatchEmbed
|
|
|
|
from easycv.models.utils import get_2d_sincos_pos_embed
|
|
from ..registry import BACKBONES
|
|
|
|
|
|
@BACKBONES.register_module
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone.

    MaskedAutoencoderViT is mostly the same as vit_tranformer_dynamic, but adds
    a random_masking func. Weights trained with MaskedAutoencoderViT can be
    loaded by vit_tranformer_dynamic. See the usage sketch at the end of this
    file.

    Args:
        img_size(int): input image size
        patch_size(int): patch size
        in_chans(int): number of input image channels
        embed_dim(int): embedding feature dimension
        depth(int): number of encoder layers
        num_heads(int): number of parallel attention heads
        mlp_ratio(float): ratio of the MLP hidden dim to embed_dim
        norm_layer: type of normalization layer
    """

    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            embed_dim=1024,
            depth=24,
            num_heads=16,
            mlp_ratio=4.,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()

        self.patch_size = patch_size
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans,
                                      embed_dim)
        self.num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim),
            requires_grad=False)  # fixed sin-cos embedding
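        # encoder-only backbone: the ViT blocks below process visible patches;
        # no MAE decoder is defined in this module.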
        self.blocks = nn.ModuleList([
            Block(
                embed_dim,
                num_heads,
                mlp_ratio,
                qkv_bias=True,
                norm_layer=norm_layer) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

    def init_weights(self):
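        # initialize the patch_embed projection like an nn.Linear layer
        # (xavier_uniform on the flattened conv weight) and the cls_token
        # with a small normal distribution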
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        torch.nn.init.normal_(self.cls_token, std=.02)
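        # fill the frozen pos_embed with fixed 2D sin-cos embeddings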
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.patch_embed.num_patches**.5),
            cls_token=True)
        self.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))

        for m in self.modules():
            if isinstance(m, nn.Linear):
                # we use xavier_uniform following official JAX ViT:
                torch.nn.init.xavier_uniform_(m.weight)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
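        # e.g. 196 patches with mask_ratio=0.75 keeps int(196 * 0.25) = 49
        # tokens per sample; the other 147 positions are marked 1 in `mask`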
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(
            x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * (1 - mask_ratio)
        x, mask, ids_restore = self.random_masking(x, mask_ratio)
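        # x now contains only the visible tokens: [N, len_keep, D]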
        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore
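

# Usage sketch (illustrative): build the encoder with its default ViT-Large
# configuration and run one masked forward pass to inspect output shapes.
# Assumes easycv and timm are importable in this package context.
if __name__ == '__main__':
    model = MaskedAutoencoderViT()  # patch 16, embed_dim 1024, depth 24 by default
    model.init_weights()
    imgs = torch.randn(2, 3, 224, 224)
    latent, mask, ids_restore = model(imgs, mask_ratio=0.75)
    # 224 / 16 = 14 -> 196 patches; 25% kept -> 49 visible tokens + 1 cls token
    print(latent.shape)       # torch.Size([2, 50, 1024])
    print(mask.shape)         # torch.Size([2, 196]), 1 marks removed patches
    print(ids_restore.shape)  # torch.Size([2, 196])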