# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import functional as F

from mmpretrain.models.backbones import MixMIMTransformer
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
from .base import BaseSelfSupervisor


@MODELS.register_module()
class MixMIMPretrainTransformer(MixMIMTransformer):
|
|
"""MixMIM backbone for MixMIM pre-training.
|
|
|
|
A PyTorch implement of : ` MixMIM: Mixed and Masked Image
|
|
Modeling for Efficient Visual Representation Learning
|
|
<https://arxiv.org/abs/2205.13137>`_
|
|
|
|
Args:
|
|
arch (str | dict): MixMIM architecture. If use string,
|
|
choose from 'base','large' and 'huge'.
|
|
If use dict, it should have below keys:
|
|
|
|
- **embed_dims** (int): The dimensions of embedding.
|
|
- **depths** (int): The number of transformer encoder layers.
|
|
- **num_heads** (int): The number of heads in attention modules.
|
|
|
|
Defaults to 'base'.
|
|
mlp_ratio (int): The mlp ratio in FFN. Defaults to 4.
|
|
img_size (int | tuple): The expected input image shape. Because we
|
|
support dynamic input shape, just set the argument to mlp_ratio
|
|
the most common input image shape. Defaults to 224.
|
|
patch_size (int | tuple): The patch size in patch embedding.
|
|
Defaults to 16.
|
|
in_channels (int): The num of input channels. Defaults to 3.
|
|
window_size (list): The height and width of the window.
|
|
qkv_bias (bool): Whether to add bias for qkv in attention modules.
|
|
Defaults to True.
|
|
patch_cfg (dict): Extra config dict for patch embedding.
|
|
Defaults to an empty dict.
|
|
norm_cfg (dict): Config dict for normalization layer.
|
|
Defaults to ``dict(type='LN')``.
|
|
drop_rate (float): Probability of an element to be zeroed.
|
|
Defaults to 0.
|
|
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
|
attn_drop_rate (float): Attention drop rate. Defaults to 0.
|
|
use_checkpoint (bool): Whether use the checkpoint to reduce GPU memory
|
|
cost. Defaults to False.
|
|
mask_ratio (bool): The base ratio of total number of patches to be
|
|
masked. Defaults to 0.5.
|
|
range_mask_ratio (float): The range of mask ratio.
|
|
Defaults to 0.
|
|
init_cfg (dict, optional): Initialization config dict.
|
|
Defaults to None.
|
|
"""

    def __init__(self,
                 arch: Union[str, dict] = 'base',
                 mlp_ratio: float = 4,
                 img_size: int = 224,
                 patch_size: int = 4,
                 in_channels: int = 3,
                 window_size: List = [14, 14, 14, 7],
                 qkv_bias: bool = True,
                 patch_cfg: dict = dict(),
                 norm_cfg: dict = dict(type='LN'),
                 drop_rate: float = 0.0,
                 drop_path_rate: float = 0.0,
                 attn_drop_rate: float = 0.0,
                 use_checkpoint: bool = False,
                 mask_ratio: float = 0.5,
                 range_mask_ratio: float = 0.0,
                 init_cfg: Optional[dict] = None) -> None:
        super().__init__(
            arch=arch,
            mlp_ratio=mlp_ratio,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            window_size=window_size,
            qkv_bias=qkv_bias,
            patch_cfg=patch_cfg,
            norm_cfg=norm_cfg,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            attn_drop_rate=attn_drop_rate,
            use_checkpoint=use_checkpoint,
            init_cfg=init_cfg)

        self.mask_ratio = mask_ratio
        self.range_mask_ratio = range_mask_ratio

    def init_weights(self):
        """Initialize position embedding, patch embedding."""
        super(MixMIMTransformer, self).init_weights()

        pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.absolute_pos_embed.shape[-1],
            cls_token=False)
        self.absolute_pos_embed.data.copy_(pos_embed.float())

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following the official JAX ViT
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
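
    # Illustrative only, not part of the original file: the sin-cos table
    # built in ``init_weights`` is fixed (not learned) and has one row per
    # patch. Assuming a 56 x 56 patch grid and 128 embedding channels:
    #
    #   pe = build_2d_sincos_position_embedding(56, 128, cls_token=False)
    #   assert pe.shape == (1, 56 * 56, 128)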

    def random_masking(self,
                       x: torch.Tensor,
                       mask_ratio: float = 0.5) -> Tuple[torch.Tensor, ...]:
        """Generate the mask for MixMIM pre-training.

        Args:
            x (torch.Tensor): Image with data augmentation applied, which is
                of shape B x C x H x W.
            mask_ratio (float): The mask ratio of total patches.
                Defaults to 0.5.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

            - mask_s1 (torch.Tensor): mask with stride of
              ``self.encoder_stride // 8``.
            - mask_s2 (torch.Tensor): mask with stride of
              ``self.encoder_stride // 4``.
            - mask_s3 (torch.Tensor): mask with stride of
              ``self.encoder_stride // 2``.
            - mask (torch.Tensor): mask with stride of
              ``self.encoder_stride``.
        """
        B, C, H, W = x.shape
        out_H = H // self.encoder_stride
        out_W = W // self.encoder_stride
        s3_H, s3_W = out_H * 2, out_W * 2
        s2_H, s2_W = out_H * 4, out_W * 4
        s1_H, s1_W = out_H * 8, out_W * 8

        seq_l = out_H * out_W
        # use a shared mask for a batch of images
        mask = torch.zeros([1, 1, seq_l], device=x.device)

        # perturb the mask ratio within the configured range
        mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio)
        noise = torch.rand(1, 1, seq_l, device=x.device)  # noise in [0, 1]
        # sort ascending: small is kept, large is masked
        mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)]
        mask.scatter_(2, mask_idx, 1)
        mask = mask.reshape(1, 1, out_H, out_W)
        mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest')
        mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest')
        mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest')

        mask = mask.reshape(1, out_H * out_W, 1).contiguous()
        mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous()
        mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous()
        mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous()

        return mask_s1, mask_s2, mask_s3, mask
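
    # Illustrative shape check, not part of the original file. Assuming
    # ``encoder_stride`` is 32 and a 224 x 224 input, the last-stage token
    # grid is 7 x 7 and the earlier stages are 2x, 4x and 8x finer:
    #
    #   model = MixMIMPretrainTransformer(arch='base')
    #   m1, m2, m3, m = model.random_masking(torch.rand(2, 3, 224, 224))
    #   assert m1.shape == (1, 56 * 56, 1)  # stride 4
    #   assert m2.shape == (1, 28 * 28, 1)  # stride 8
    #   assert m3.shape == (1, 14 * 14, 1)  # stride 16
    #   assert m.shape == (1, 7 * 7, 1)     # stride 32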

    def forward(self,
                x: torch.Tensor,
                mask: Optional[bool] = True) -> Tuple[torch.Tensor, ...]:
        """Generate features for masked images.

        This function generates the mask, randomly masks some patches, and
        gets the hidden features of the visible patches.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): Whether to mask the input images.
                Defaults to True.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:

            - x (torch.Tensor): hidden features, which is of shape
              B x L x C.
            - mask_s4 (torch.Tensor): the mask tensor for the last layer.
        """
        if mask is None or mask is False:
            return super().forward(x)

        else:
            mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking(
                x, self.mask_ratio)

            x, _ = self.patch_embed(x)

            # mix each image with its batch-reversed partner: masked token
            # positions are filled with tokens from the flipped image
            x = x * (1. - mask_s1) + x.flip(0) * mask_s1
            x = x + self.absolute_pos_embed
            x = self.drop_after_pos(x)

            # feed each stage the mask resized to its token resolution
            for idx, layer in enumerate(self.layers):
                if idx == 0:
                    x = layer(x, attn_mask=mask_s1)
                elif idx == 1:
                    x = layer(x, attn_mask=mask_s2)
                elif idx == 2:
                    x = layer(x, attn_mask=mask_s3)
                elif idx == 3:
                    x = layer(x, attn_mask=mask_s4)

            x = self.norm(x)

            return x, mask_s4
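

# Illustrative only, not part of the original file: the mixing step
# ``x = x * (1. - mask_s1) + x.flip(0) * mask_s1`` pairs each image with its
# batch-reversed partner, so for a batch [a, b, c, d] the masked token
# positions of ``a`` are filled with tokens from ``d``, and vice versa.
# A self-contained sketch of the same operation:
#
#   tokens = torch.arange(4.).reshape(4, 1, 1).expand(4, 3136, 128)
#   mask = torch.zeros(1, 3136, 1)
#   mask[:, ::2] = 1.  # mask every other token position
#   mixed = tokens * (1. - mask) + tokens.flip(0) * mask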


@MODELS.register_module()
class MixMIM(BaseSelfSupervisor):
    """MixMIM.

    Implementation of `MixMIM: Mixed and Masked Image Modeling for Efficient
    Visual Representation Learning <https://arxiv.org/abs/2205.13137>`_.
    """

    def __init__(self,
                 backbone: dict,
                 neck: Optional[dict] = None,
                 head: Optional[dict] = None,
                 pretrained: Optional[str] = None,
                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                 init_cfg: Optional[dict] = None):
        # the head reconstructs patches whose size equals the encoder stride
        # of the neck, so propagate that value into the head config
        head.update(dict(patch_size=neck['encoder_stride']))
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features from the input images without masking."""
        return self.backbone(inputs, mask=None)

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        # mask the images and encode the mixed visible tokens
        latent, mask = self.backbone(inputs)
        # reconstruct pixels from the latent features and the mask
        x_rec = self.neck(latent, mask)
        loss = self.head.loss(x_rec, inputs, mask)
        losses = dict(loss=loss)
        return losses
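

# A hedged end-to-end sketch, not part of the original file: wiring the
# pieces together the way an mmpretrain pre-training config would. The neck
# and head type names (``MixMIMPretrainDecoder``, ``MixMIMPretrainHead``) and
# the field values are assumptions; check the shipped MixMIM configs for the
# exact settings.
#
#   model = dict(
#       type='MixMIM',
#       backbone=dict(type='MixMIMPretrainTransformer', arch='base'),
#       neck=dict(type='MixMIMPretrainDecoder', encoder_stride=32),
#       head=dict(
#           type='MixMIMPretrainHead',
#           loss=dict(type='PixelReconstructionLoss', criterion='L2')))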