195 lines
7.5 KiB
Python
195 lines
7.5 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||
|
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
from mmpretrain.models.selfsup.mae import MAE, MAEViT
|
||
|
from mmpretrain.registry import MODELS
|
||
|
from mmpretrain.structures import DataSample
|
||
|
|
||
|
|
||
|
@MODELS.register_module()
class MFFViT(MAEViT):
    """Vision Transformer for MFF Pretraining.

    This class inherits all the functionality of ``MAEViT`` and adds
    multi-level feature fusion to it. For more details, please refer to
    `Improving Pixel-based MIM by Reducing Wasted Modeling Capability`.

    Args:
        arch (str | dict): Vision Transformer architecture.
            Defaults to 'b'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        out_indices (Sequence | int): Output from which stages. In the
            masked forward pass the outputs of these stages are fused; every
            listed stage except the last one is first passed through its own
            linear projection. Defaults to -1, means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            It only works without input mask. Defaults to ``"raw"``.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        mask_ratio (float): The ratio of total number of patches to be
            masked. Defaults to 0.75.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 arch: Union[str, dict] = 'b',
                 img_size: int = 224,
                 patch_size: int = 16,
                 out_indices: Union[Sequence, int] = -1,
                 drop_rate: float = 0,
                 drop_path_rate: float = 0,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 final_norm: bool = True,
                 out_type: str = 'raw',
                 interpolate_mode: str = 'bicubic',
                 patch_cfg: dict = dict(),
                 layer_cfgs: dict = dict(),
                 mask_ratio: float = 0.75,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            out_type=out_type,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            mask_ratio=mask_ratio,
            init_cfg=init_cfg)
        # One linear projection for every fused stage except the last entry
        # of ``out_indices``, whose output is fused as-is in ``forward``.
        proj_layers = [
            torch.nn.Linear(self.embed_dims, self.embed_dims)
            for _ in range(len(self.out_indices) - 1)
        ]
        self.proj_layers = torch.nn.ModuleList(proj_layers)
        # Learnable fusion logits, one per entry of ``out_indices``. They are
        # shaped (K, 1, 1, 1) so that, after a softmax over dim 0, they
        # broadcast along the leading (stacking) dimension of the stacked
        # per-stage features in ``forward``.
        self.proj_weights = torch.nn.Parameter(
            torch.ones(len(self.out_indices)).view(-1, 1, 1, 1))
        # With a single output stage the softmax weight is always 1, so
        # there is nothing to learn.
        if len(self.out_indices) == 1:
            self.proj_weights.requires_grad = False

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[bool] = True
    ) -> Tuple[torch.Tensor, ...]:
        """Generate features for masked images.

        The function supports two kinds of forward behaviors. If ``mask`` is
        ``True``, the function randomly masks some patches, runs the cls
        token and the visible patches through the encoder, and fuses the
        outputs of the stages in ``out_indices`` with learnable
        softmax-normalized weights. In that case the function is executed as
        masked image modeling pre-training. If ``mask`` is ``None`` or
        ``False``, the parent's unmasked forward is used to extract features
        from the images.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (bool, optional): Whether to mask patches in the forward
                pass. Defaults to True.

        Returns:
            Tuple[torch.Tensor, ...]: Without masking, the output of the
            parent's unmasked forward. With masking, a 4-tuple of:

            - ``x`` (torch.Tensor): normalized, fused hidden features of the
              cls token and the visible patches.
            - ``mask`` (torch.Tensor): mask used to mask the image.
            - ``ids_restore`` (torch.Tensor): ids to restore the original
              image.
            - ``proj_weights`` (torch.Tensor): the softmax-normalized fusion
              weights, flattened to shape (K, ) with
              K = ``len(self.out_indices)``.
        """
        # ``mask is None or False`` would never match ``mask=False``;
        # compare against both sentinels explicitly.
        if mask is None or mask is False:
            # NOTE(review): ``MAEViT.forward`` presumably defaults to
            # ``mask=True`` like this method, so a bare ``super().forward(x)``
            # would run the masked path again. Request the unmasked path
            # explicitly.
            return super().forward(x, mask=None)

        B = x.shape[0]
        x = self.patch_embed(x)[0]
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: keep a random subset of the patch tokens
        x, mask, ids_restore = self.random_masking(x, self.mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Collect the outputs of the stages listed in ``out_indices``; all
        # but the last listed stage go through their own projection layer.
        last_index = self.out_indices[-1]
        res = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.out_indices:
                if i != last_index:
                    proj_x = self.proj_layers[self.out_indices.index(i)](x)
                else:
                    proj_x = x
                res.append(proj_x)

        # Weighted sum of the stage features with softmax-normalized weights.
        res = torch.stack(res)
        proj_weights = F.softmax(self.proj_weights, dim=0)
        res = (res * proj_weights).sum(dim=0)

        # Use final norm
        x = self.norm1(res)
        return (x, mask, ids_restore, proj_weights.view(-1))
|
||
|
|
||
|
|
||
|
@MODELS.register_module()
class MFF(MAE):
    """MFF.

    Implementation of `Improving Pixel-based MIM by Reducing Wasted Modeling
    Capability`. It follows the ``MAE`` pipeline, but expects a backbone
    whose masked forward additionally returns the fusion weights (such as
    ``MFFViT``). Those weights are returned alongside the reconstruction
    loss.
    """

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """Compute the MFF pre-training losses.

        Args:
            inputs (torch.Tensor): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function. Not used by this method.

        Returns:
            Dict[str, torch.Tensor]: The reconstruction loss under the key
            ``loss``, followed by one ``weight_{i}`` entry for each fusion
            weight returned by the backbone.
        """
        # The backbone yields the latent features, the patch mask, the ids
        # the decoder uses to restore the original token order, and the
        # multi-level fusion weights.
        latent, mask, ids_restore, fusion_weights = self.backbone(inputs)
        reconstruction = self.neck(latent, ids_restore)

        losses = {'loss': self.head.loss(reconstruction, inputs, mask)}
        # One entry per fusion weight (presumably so they can be logged).
        for idx, weight in enumerate(fusion_weights):
            losses[f'weight_{idx}'] = weight
        return losses
|