[Feature]: Add MFF
parent 340d187765, commit 2706f5c8c2
@@ -0,0 +1,24 @@
_base_ = '../mae/mae_vit-base-p16_8xb512-amp-coslr-300e_in1k.py'

randomness = dict(seed=2, diff_rank_seed=True)

# dataset config
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='ToPIL', to_rgb=True),
    dict(type='torchvision/Resize', size=224),
    dict(
        type='torchvision/RandomCrop',
        size=224,
        padding=4,
        padding_mode='reflect'),
    dict(type='torchvision/RandomHorizontalFlip', p=0.5),
    dict(type='ToNumpy', to_bgr=True),
    dict(type='PackInputs')
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

# model config
model = dict(
    type='MFF', backbone=dict(type='MFFViT', out_indices=[1, 3, 5, 7, 9, 11]))
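Note on the pipeline: the torchvision/-prefixed entries wrap the torchvision transforms of the same name, with ToPIL/ToNumpy converting between the loader's numpy arrays and PIL images around them. For intuition only, a rough torchvision equivalent of the geometric part of this pipeline (a sketch, not how MMPretrain builds it internally):

import torchvision.transforms as T

# Resize -> reflect-padded random crop -> random horizontal flip,
# mirroring the three torchvision/ entries above.
equivalent = T.Compose([
    T.Resize(224),
    T.RandomCrop(224, padding=4, padding_mode='reflect'),
    T.RandomHorizontalFlip(p=0.5),
])

out_indices=[1, 3, 5, 7, 9, 11] selects six of the twelve transformer blocks in ViT-base as the feature levels that MFFViT fuses.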
@@ -9,6 +9,7 @@ from .eva import EVA
 from .itpn import iTPN, iTPNHiViT
 from .mae import MAE, MAEHiViT, MAEViT
 from .maskfeat import HOGGenerator, MaskFeat, MaskFeatViT
+from .mff import MFF, MFFViT
 from .milan import MILAN, CLIPGenerator, MILANViT
 from .mixmim import MixMIM, MixMIMPretrainTransformer
 from .moco import MoCo
@@ -20,37 +21,10 @@ from .spark import SparK
 from .swav import SwAV
 
 __all__ = [
-    'BaseSelfSupervisor',
-    'BEiTPretrainViT',
-    'VQKD',
-    'CAEPretrainViT',
-    'DALLEEncoder',
-    'MAEViT',
-    'MAEHiViT',
-    'iTPNHiViT',
-    'iTPN',
-    'HOGGenerator',
-    'MaskFeatViT',
-    'CLIPGenerator',
-    'MILANViT',
-    'MixMIMPretrainTransformer',
-    'MoCoV3ViT',
-    'SimMIMSwinTransformer',
-    'MoCo',
-    'MoCoV3',
-    'BYOL',
-    'SimCLR',
-    'SimSiam',
-    'BEiT',
-    'CAE',
-    'MAE',
-    'MaskFeat',
-    'MILAN',
-    'MixMIM',
-    'SimMIM',
-    'EVA',
-    'DenseCL',
-    'BarlowTwins',
-    'SwAV',
-    'SparK',
+    'BaseSelfSupervisor', 'BEiTPretrainViT', 'VQKD', 'CAEPretrainViT',
+    'DALLEEncoder', 'MAEViT', 'MAEHiViT', 'iTPNHiViT', 'iTPN', 'HOGGenerator',
+    'MaskFeatViT', 'CLIPGenerator', 'MILANViT', 'MixMIMPretrainTransformer',
+    'MoCoV3ViT', 'SimMIMSwinTransformer', 'MoCo', 'MoCoV3', 'BYOL', 'SimCLR',
+    'SimSiam', 'BEiT', 'CAE', 'MAE', 'MaskFeat', 'MILAN', 'MixMIM', 'SimMIM',
+    'EVA', 'DenseCL', 'BarlowTwins', 'SwAV', 'SparK', 'MFF', 'MFFViT'
 ]
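With this export in place, both new classes resolve from the package namespace, e.g.:

from mmpretrain.models.selfsup import MFF, MFFViT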
@@ -0,0 +1,195 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F

from mmpretrain.models.selfsup.mae import MAE, MAEViT
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


@MODELS.register_module()
class MFFViT(MAEViT):
    """Vision Transformer for MFF pre-training.

    This class inherits all its functionalities from ``MAEViT`` and adds
    multi-level feature fusion on top. For more details, you can refer to
    `Improving Pixel-based MIM by Reducing Wasted Modeling Capability`.

    Args:
        arch (str | dict): Vision Transformer architecture.
            Defaults to 'b'.
        img_size (int | tuple): Input image size.
        patch_size (int | tuple): The patch size.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor including patch tokens and
              class tokens with shape (B, L, C).

            It only works without input mask. Defaults to ``"raw"``.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            the encoder. Defaults to an empty dict.
        mask_ratio (float): The ratio of total number of patches to be
            masked. Defaults to 0.75.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 arch: Union[str, dict] = 'b',
                 img_size: int = 224,
                 patch_size: int = 16,
                 out_indices: Union[Sequence, int] = -1,
                 drop_rate: float = 0,
                 drop_path_rate: float = 0,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
                 out_type: str = 'raw',
                 interpolate_mode: str = 'bicubic',
                 patch_cfg: dict = dict(),
                 layer_cfgs: dict = dict(),
                 mask_ratio: float = 0.75,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            out_type=out_type,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            mask_ratio=mask_ratio,
            init_cfg=init_cfg)
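        # One linear projection per intermediate output layer; the deepest
        # selected layer is fused without projection (see ``forward``).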
        proj_layers = [
            torch.nn.Linear(self.embed_dims, self.embed_dims)
            for _ in range(len(self.out_indices) - 1)
        ]
        self.proj_layers = torch.nn.ModuleList(proj_layers)
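        # Learnable per-layer fusion weights, softmax-normalized in
        # ``forward``; with a single output layer there is nothing to fuse,
        # so the weight is frozen.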
        self.proj_weights = torch.nn.Parameter(
            torch.ones(len(self.out_indices)).view(-1, 1, 1, 1))
        if len(self.out_indices) == 1:
            self.proj_weights.requires_grad = False

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[bool] = True
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate features for masked images.

        The function supports two kinds of forward behaviors. If the ``mask``
        is ``True``, the function will randomly generate a mask, mask some
        patches, and compute the hidden features of the visible patches,
        i.e. it runs as masked image modeling pre-training; if the ``mask``
        is ``None`` or ``False``, the forward function will call
        ``super().forward()``, which extracts features from images without
        any mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): To indicate whether the forward function
                generates a ``mask`` or not.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            Hidden features, mask, the ids to restore the original image,
            and the fusion weights.

            - ``x`` (torch.Tensor): hidden features, which is of shape
              B x (L * (1 - mask_ratio)) x C.
            - ``mask`` (torch.Tensor): mask used to mask image.
            - ``ids_restore`` (torch.Tensor): ids to restore original image.
            - ``weights`` (torch.Tensor): softmax-normalized fusion weights,
              one per fused layer.
        """
        if mask is None or mask is False:
            return super().forward(x)
        else:
            B = x.shape[0]
            x = self.patch_embed(x)[0]
            # add pos embed w/o cls token
            x = x + self.pos_embed[:, 1:, :]

            # masking: length -> length * mask_ratio
            x, mask, ids_restore = self.random_masking(x, self.mask_ratio)

            # append cls token
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

            res = []
            for i, layer in enumerate(self.layers):
                x = layer(x)
                if i in self.out_indices:
                    if i != self.out_indices[-1]:
                        proj_x = self.proj_layers[self.out_indices.index(i)](x)
                    else:
                        proj_x = x
                    res.append(proj_x)
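            # Fuse the multi-level features: stack to (N, B, L', C), then
            # take a softmax-weighted sum over the N selected layers.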
            res = torch.stack(res)
            proj_weights = F.softmax(self.proj_weights, dim=0)
            res = res * proj_weights
            res = res.sum(dim=0)

            # Use final norm
            x = self.norm1(res)

            return (x, mask, ids_restore, proj_weights.view(-1))


@MODELS.register_module()
class MFF(MAE):
    """MFF.

    Implementation of `Improving Pixel-based MIM by Reducing Wasted Modeling
    Capability <https://arxiv.org/abs/2308.00261>`_.
    """

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        # ids_restore: the same as that in original repo, which is used
        # to recover the original order of tokens in decoder.
        latent, mask, ids_restore, weights = self.backbone(inputs)
        pred = self.neck(latent, ids_restore)
        loss = self.head.loss(pred, inputs, mask)
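        # Expose each layer's fusion weight in the returned dict so it is
        # logged alongside the reconstruction loss during training.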
        weight_params = {
            f'weight_{i}': weights[i]
            for i in range(weights.size(0))
        }
        losses = dict(loss=loss)
        losses.update(weight_params)
        return losses
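For reference, the fusion performed in MFFViT.forward reduces to a softmax-weighted sum over the selected layers. A minimal standalone sketch with toy shapes (variable names here are illustrative, not part of the codebase):

import torch
import torch.nn.functional as F

N, B, L, C = 6, 2, 50, 768          # fused layers, batch, visible tokens, dims
feats = torch.randn(N, B, L, C)     # stacked per-layer features
proj_weights = torch.nn.Parameter(torch.ones(N).view(-1, 1, 1, 1))

w = F.softmax(proj_weights, dim=0)  # weights sum to 1 across the N layers
fused = (feats * w).sum(dim=0)      # broadcast multiply, reduce layer dim
assert fused.shape == (B, L, C)

Because the weights are initialized equal and softmax-normalized, training starts from a uniform average of the selected layers; the logged weight_i values then show how that balance shifts.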