[Feature]: Add MFF

pull/1725/head
liuyuan 2023-07-21 19:23:32 +08:00
parent 340d187765
commit 2706f5c8c2
3 changed files with 226 additions and 33 deletions

View File

@@ -0,0 +1,24 @@
_base_ = '../mae/mae_vit-base-p16_8xb512-amp-coslr-300e_in1k.py'
randomness = dict(seed=2, diff_rank_seed=True)
# dataset config
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ToPIL', to_rgb=True),
    dict(type='torchvision/Resize', size=224),
    dict(
        type='torchvision/RandomCrop',
        size=224,
        padding=4,
        padding_mode='reflect'),
    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
    dict(type='ToNumpy', to_bgr=True),
    dict(type='PackInputs')
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
# model config
model = dict(
    type='MFF', backbone=dict(type='MFFViT', out_indices=[1, 3, 5, 7, 9, 11]))
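A quick way to sanity-check the settings above is to build just the MFFViT backbone this config declares and run a single masked forward pass. The snippet below is a minimal sketch, assuming an environment with this commit installed; the explicit arch, patch_size and mask_ratio arguments simply restate defaults inherited from the MAE base config.

import torch
from mmpretrain.registry import MODELS

# Build only the backbone declared in the model config above.
backbone = MODELS.build(
    dict(
        type='MFFViT',
        arch='b',
        patch_size=16,
        out_indices=[1, 3, 5, 7, 9, 11],
        mask_ratio=0.75))
backbone.eval()

imgs = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    latent, mask, ids_restore, weights = backbone(imgs)  # masked forward

print(latent.shape)   # (2, 50, 768): 1 cls token + 49 visible patches (25% of 196)
print(mask.shape)     # (2, 196): binary mask over all patches
print(weights.shape)  # (6,): softmax fusion weights, one per entry in out_indices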

View File

@@ -9,6 +9,7 @@ from .eva import EVA
from .itpn import iTPN, iTPNHiViT
from .mae import MAE, MAEHiViT, MAEViT
from .maskfeat import HOGGenerator, MaskFeat, MaskFeatViT
from .mff import MFF, MFFViT
from .milan import MILAN, CLIPGenerator, MILANViT
from .mixmim import MixMIM, MixMIMPretrainTransformer
from .moco import MoCo
@@ -20,37 +21,10 @@ from .spark import SparK
from .swav import SwAV
__all__ = [
'BaseSelfSupervisor',
'BEiTPretrainViT',
'VQKD',
'CAEPretrainViT',
'DALLEEncoder',
'MAEViT',
'MAEHiViT',
'iTPNHiViT',
'iTPN',
'HOGGenerator',
'MaskFeatViT',
'CLIPGenerator',
'MILANViT',
'MixMIMPretrainTransformer',
'MoCoV3ViT',
'SimMIMSwinTransformer',
'MoCo',
'MoCoV3',
'BYOL',
'SimCLR',
'SimSiam',
'BEiT',
'CAE',
'MAE',
'MaskFeat',
'MILAN',
'MixMIM',
'SimMIM',
'EVA',
'DenseCL',
'BarlowTwins',
'SwAV',
'SparK',
'BaseSelfSupervisor', 'BEiTPretrainViT', 'VQKD', 'CAEPretrainViT',
'DALLEEncoder', 'MAEViT', 'MAEHiViT', 'iTPNHiViT', 'iTPN', 'HOGGenerator',
'MaskFeatViT', 'CLIPGenerator', 'MILANViT', 'MixMIMPretrainTransformer',
'MoCoV3ViT', 'SimMIMSwinTransformer', 'MoCo', 'MoCoV3', 'BYOL', 'SimCLR',
'SimSiam', 'BEiT', 'CAE', 'MAE', 'MaskFeat', 'MILAN', 'MixMIM', 'SimMIM',
'EVA', 'DenseCL', 'BarlowTwins', 'SwAV', 'SparK', 'MFF', 'MFFViT'
]
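With the two new exports above, the MFF classes should be reachable both as a direct import and through the model registry. A quick hedged check, assuming an environment with this commit installed:

from mmpretrain.models.selfsup import MFF, MFFViT
from mmpretrain.registry import MODELS

# Both classes are registered by the @MODELS.register_module() decorators in mff.py.
assert MODELS.get('MFF') is MFF
assert MODELS.get('MFFViT') is MFFViT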

View File

@@ -0,0 +1,195 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F

from mmpretrain.models.selfsup.mae import MAE, MAEViT
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


@MODELS.register_module()
class MFFViT(MAEViT):
"""Vision Transformer for MFF Pretraining.
This class inherits all these functionalities from ``MAEViT``, and
add multi-level feature fusion to it. For more details, you can
refer to `Improving Pixel-based MIM by Reducing Wasted Modeling
Capability`.
Args:
arch (str | dict): Vision Transformer architecture
Default: 'b'
img_size (int | tuple): Input image size
patch_size (int | tuple): The patch size
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
It only works without input mask. Defaults to ``"avg_featmap"``.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
mask_ratio (bool): The ratio of total number of patches to be masked.
Defaults to 0.75.
init_cfg (Union[List[dict], dict], optional): Initialization config
dict. Defaults to None.
"""
    def __init__(self,
                 arch: Union[str, dict] = 'b',
                 img_size: int = 224,
                 patch_size: int = 16,
                 out_indices: Union[Sequence, int] = -1,
                 drop_rate: float = 0,
                 drop_path_rate: float = 0,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
                 out_type: str = 'raw',
                 interpolate_mode: str = 'bicubic',
                 patch_cfg: dict = dict(),
                 layer_cfgs: dict = dict(),
                 mask_ratio: float = 0.75,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            out_type=out_type,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            mask_ratio=mask_ratio,
            init_cfg=init_cfg)
        proj_layers = [
            torch.nn.Linear(self.embed_dims, self.embed_dims)
            for _ in range(len(self.out_indices) - 1)
        ]
        self.proj_layers = torch.nn.ModuleList(proj_layers)
        self.proj_weights = torch.nn.Parameter(
            torch.ones(len(self.out_indices)).view(-1, 1, 1, 1))
        if len(self.out_indices) == 1:
            self.proj_weights.requires_grad = False

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[bool] = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate features for masked images.

        The function supports two kinds of forward behaviors. If ``mask`` is
        ``True``, the function generates a random mask, masks some patches and
        returns the hidden features of the visible patches, i.e. it runs as
        masked image modeling pre-training. If ``mask`` is ``None`` or
        ``False``, the function calls ``super().forward()``, which extracts
        features from images without any mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): Whether to generate a mask in the forward
                pass. Defaults to True.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            Hidden features, mask, the ids to restore the original image and
            the fusion weights.

            - ``x`` (torch.Tensor): hidden features, which is of shape
              B x (L * (1 - mask_ratio)) x C.
            - ``mask`` (torch.Tensor): mask used to mask the image.
            - ``ids_restore`` (torch.Tensor): ids to restore the original
              image.
            - ``proj_weights`` (torch.Tensor): softmax-normalized fusion
              weights over the selected layers.
        """
        if mask is None or mask is False:
            return super().forward(x)
        else:
            B = x.shape[0]
            x = self.patch_embed(x)[0]
            # add pos embed w/o cls token
            x = x + self.pos_embed[:, 1:, :]

            # masking: length -> length * (1 - mask_ratio)
            x, mask, ids_restore = self.random_masking(x, self.mask_ratio)

            # append cls token
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

            # collect the outputs of the selected layers; all but the last one
            # are passed through a dedicated linear projection layer
            res = []
            for i, layer in enumerate(self.layers):
                x = layer(x)
                if i in self.out_indices:
                    if i != self.out_indices[-1]:
                        proj_x = self.proj_layers[self.out_indices.index(i)](x)
                    else:
                        proj_x = x
                    res.append(proj_x)

            # fuse the multi-level features with learnable softmax weights
            res = torch.stack(res)
            proj_weights = F.softmax(self.proj_weights, dim=0)
            res = res * proj_weights
            res = res.sum(dim=0)

            # use the final norm on the fused features
            x = self.norm1(res)

            return (x, mask, ids_restore, proj_weights.view(-1))


@MODELS.register_module()
class MFF(MAE):
    """MFF.

    Implementation of `Improving Pixel-based MIM by Reducing Wasted Modeling
    Capability`.
    """

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        # ids_restore: the same as that in the original repo, which is used
        # to recover the original order of tokens in the decoder.
        latent, mask, ids_restore, weights = self.backbone(inputs)
        pred = self.neck(latent, ids_restore)
        loss = self.head.loss(pred, inputs, mask)

        # expose the learned fusion weights so they can be logged alongside
        # the reconstruction loss
        weight_params = {
            f'weight_{i}': weights[i]
            for i in range(weights.size(0))
        }
        losses = dict(loss=loss)
        losses.update(weight_params)
        return losses
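For completeness, a sketch of how the new loss dictionary looks end to end. The neck and head settings below are assumptions mirroring the MAE ViT-B pretrain config that the new MFF config inherits from; they are not part of this diff. Besides the reconstruction loss, the dict carries one scalar per fused layer (weight_0 ... weight_5), so the learned fusion weights can be logged and monitored during training.

import torch
from mmpretrain.registry import MODELS

# Assumed neck/head settings, mirroring the MAE ViT-B pretrain base config
# (not shown in this diff).
model = MODELS.build(
    dict(
        type='MFF',
        backbone=dict(type='MFFViT', out_indices=[1, 3, 5, 7, 9, 11]),
        neck=dict(
            type='MAEPretrainDecoder',
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            decoder_embed_dim=512,
            decoder_depth=8,
            decoder_num_heads=16,
            mlp_ratio=4.),
        head=dict(
            type='MAEPretrainHead',
            norm_pix=True,
            patch_size=16,
            loss=dict(type='PixelReconstructionLoss', criterion='L2'))))

imgs = torch.randn(2, 3, 224, 224)
# data_samples is not used by MFF.loss, so None is enough for a smoke test.
losses = model.loss(imgs, data_samples=None)
print(sorted(losses.keys()))
# ['loss', 'weight_0', 'weight_1', 'weight_2', 'weight_3', 'weight_4', 'weight_5']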