[Refactor] Add self-supervised backbones and target generators. (#1379)
* add heads
* add losses
* fix
* remove mim head
* add modified backbones and target generators
* add unittest
* refactor caevit
* add window_size check
* fix lint
* apply new DataSample
* fix ut error
* update ut
* fix ut
* fix lint
* Update base modules.

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
parent 63d9f27fde
commit e453a45d31
@@ -47,7 +47,7 @@ repos:
|
|||
rev: v0.4.0
|
||||
hooks:
|
||||
- id: check-copyright
|
||||
args: ["mmcls", "tests", "demo", "tools", "--excludes", "mmcls/.mim/", "--ignore-file-not-found-error"]
|
||||
args: ["mmpretrain", "tests", "demo", "tools", "--excludes", "mmpretrain/.mim/", "--ignore-file-not-found-error"]
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: metafile
|
||||
|
|
|
@@ -0,0 +1 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
@@ -9,7 +9,6 @@ from .losses import * # noqa: F401,F403
|
|||
from .necks import * # noqa: F401,F403
|
||||
from .retrievers import * # noqa: F401,F403
|
||||
from .selfsup import * # noqa: F401,F403
|
||||
from .target_generators import * # noqa: F401,F403
|
||||
from .tta import * # noqa: F401,F403
|
||||
from .utils import * # noqa: F401,F403
|
||||
|
||||
|
|
|
@@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .alexnet import AlexNet
|
||||
from .beit import BEiT
|
||||
from .beit import BEiTViT
|
||||
from .conformer import Conformer
|
||||
from .convmixer import ConvMixer
|
||||
from .convnext import ConvNeXt
|
||||
|
@@ -106,7 +106,7 @@ __all__ = [
|
|||
'HorNet',
|
||||
'MobileViT',
|
||||
'DaViT',
|
||||
'BEiT',
|
||||
'BEiTViT',
|
||||
'RevVisionTransformer',
|
||||
'MixMIMTransformer',
|
||||
'TinyViT',
|
||||
|
|
|
@@ -155,7 +155,6 @@ class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
|
|||
drop_path_rate=0.,
|
||||
drop_rate=0.,
|
||||
num_fcs=num_fcs,
|
||||
qkv_bias=bias,
|
||||
act_cfg=act_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
init_cfg=init_cfg)
|
||||
|
@@ -214,7 +213,7 @@ class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
|
|||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BEiT(VisionTransformer):
|
||||
class BEiTViT(VisionTransformer):
|
||||
"""Backbone for BEiT.
|
||||
|
||||
A PyTorch implement of : `BEiT: BERT Pre-Training of Image Transformers
|
||||
|
@@ -244,8 +243,10 @@ class BEiT(VisionTransformer):
|
|||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
qkv_bias (bool): Whether to add bias for qkv in attention modules.
|
||||
Defaults to True.
|
||||
bias (bool | str): The option to add learnable bias for q, k, v. If bias
is True, it will add learnable bias. If bias is 'qv_bias', it will
only add learnable bias for q, v. If bias is False, it will not add
bias for q, k, v. Defaults to 'qv_bias'.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add an additional layer to normalize
|
||||
|
@@ -285,6 +286,7 @@ class BEiT(VisionTransformer):
|
|||
out_indices=-1,
|
||||
drop_rate=0,
|
||||
drop_path_rate=0,
|
||||
bias='qv_bias',
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
final_norm=False,
|
||||
with_cls_token=True,
|
||||
|
@@ -395,6 +397,7 @@ class BEiT(VisionTransformer):
|
|||
use_rel_pos_bias=use_rel_pos_bias,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
bias=bias,
|
||||
norm_cfg=norm_cfg)
|
||||
_layer_cfg.update(layer_cfgs[i])
|
||||
self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
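After the rename from BEiT to BEiTViT, classification configs refer to the backbone by its new registered name. A hedged config sketch (the field values are illustrative, not taken from this PR):

# Hypothetical backbone config using the renamed class; values are examples.
backbone = dict(
    type='BEiTViT',
    arch='base',
    img_size=224,
    patch_size=16,
    out_indices=-1,
    norm_cfg=dict(type='LN', eps=1e-6),
    final_norm=False)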
|
||||
|
|
|
@@ -30,6 +30,8 @@ class ImageClassifier(BaseClassifier):
|
|||
- augments (List[dict]): The batch augmentation methods to use.
|
||||
More details can be found in
|
||||
:mod:`mmpretrain.model.utils.augment`.
|
||||
- probs (List[float], optional): The probability of every batch
|
||||
augmentation methods. If None, choose evenly. Defaults to None.
|
||||
|
||||
Defaults to None.
|
||||
data_preprocessor (dict, optional): The config for preprocessing input
|
||||
|
@@ -51,14 +53,16 @@ class ImageClassifier(BaseClassifier):
|
|||
if pretrained is not None:
|
||||
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
|
||||
if data_preprocessor is None:
|
||||
data_preprocessor = {}
|
||||
# The build process is in MMEngine, so we need to add scope here.
|
||||
data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor')
|
||||
data_preprocessor = data_preprocessor or {}
|
||||
|
||||
if train_cfg is not None and 'augments' in train_cfg:
|
||||
# Set batch augmentations by `train_cfg`
|
||||
data_preprocessor['batch_augments'] = train_cfg
|
||||
if isinstance(data_preprocessor, dict):
|
||||
data_preprocessor.setdefault('type', 'ClsDataPreprocessor')
|
||||
data_preprocessor.setdefault('batch_augments', train_cfg)
|
||||
data_preprocessor = MODELS.build(data_preprocessor)
|
||||
elif not isinstance(data_preprocessor, nn.Module):
|
||||
raise TypeError('data_preprocessor should be a `dict` or '
|
||||
f'`nn.Module` instance, but got '
|
||||
f'{type(data_preprocessor)}')
|
||||
|
||||
super(ImageClassifier, self).__init__(
|
||||
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
|
||||
|
@@ -82,16 +86,13 @@ class ImageClassifier(BaseClassifier):
|
|||
|
||||
The method should accept three modes: "tensor", "predict" and "loss":
|
||||
|
||||
- "tensor": Forward the whole network and return tensor or tuple of
|
||||
tensor without any post-processing, same as a common nn.Module.
|
||||
- "tensor": Forward the whole network and return tensor(s) without any
|
||||
post-processing, same as a common PyTorch Module.
|
||||
- "predict": Forward and return the predictions, which are fully
|
||||
processed to a list of :obj:`DataSample`.
|
||||
- "loss": Forward and return a dict of losses according to the given
|
||||
inputs and data samples.
|
||||
|
||||
Note that this method doesn't handle either back propagation or
optimizer updating, which are done in :meth:`train_step`.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
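A hypothetical call-site sketch of the modes described above (the classifier config is an assumption for illustration, any compatible backbone/head pair would do; only the mode names come from this docstring):

import torch
from mmpretrain.registry import MODELS

# Assumed example config, not part of this PR.
model = MODELS.build(
    dict(type='ImageClassifier',
         backbone=dict(type='ResNet', depth=18),
         head=dict(type='LinearClsHead', num_classes=10, in_channels=512)))

inputs = torch.rand(1, 3, 224, 224)
logits = model(inputs, mode='tensor')    # raw tensor(s), no post-processing
preds = model(inputs, mode='predict')    # list of DataSample with predictions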
|
||||
|
|
|
@@ -0,0 +1,26 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base import BaseSelfSupervisor
|
||||
from .beit import VQKD, BEiTPretrainViT
|
||||
from .cae import CAEViT, Encoder
|
||||
from .mae import MAEViT
|
||||
from .maskfeat import HOGGenerator, MaskFeatViT
|
||||
from .milan import CLIPGenerator, MILANViT
|
||||
from .mixmim import MixMIMPretrainTransformer
|
||||
from .mocov3 import MoCoV3ViT
|
||||
from .simmim import SimMIMSwinTransformer
|
||||
|
||||
__all__ = [
|
||||
'BaseSelfSupervisor',
|
||||
'BEiTPretrainViT',
|
||||
'VQKD',
|
||||
'CAEViT',
|
||||
'Encoder',
|
||||
'MAEViT',
|
||||
'HOGGenerator',
|
||||
'MaskFeatViT',
|
||||
'CLIPGenerator',
|
||||
'MILANViT',
|
||||
'MixMIMPretrainTransformer',
|
||||
'MoCoV3ViT',
|
||||
'SimMIMSwinTransformer',
|
||||
]
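Every name exported above is registered in MODELS, so the new modules can be built from config dicts. A small hedged sketch (the constructor arguments mirror the defaults added later in this PR):

from mmpretrain.registry import MODELS

# Both types are registered in the files added by this PR.
target_generator = MODELS.build(dict(type='HOGGenerator', nbins=9, pool=8))
backbone = MODELS.build(dict(type='MAEViT', arch='b', patch_size=16, mask_ratio=0.75))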
|
|
@@ -0,0 +1,164 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseModel
|
||||
from torch import nn
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
class BaseSelfSupervisor(BaseModel, metaclass=ABCMeta):
|
||||
"""BaseModel for Self-Supervised Learning.
|
||||
|
||||
All self-supervised algorithms should inherit this module.
|
||||
|
||||
Args:
|
||||
backbone (dict): The backbone module. See
|
||||
:mod:`mmpretrain.models.backbones`.
|
||||
neck (dict, optional): The neck module to process features from
|
||||
backbone. See :mod:`mmpretrain.models.necks`. Defaults to None.
|
||||
head (dict, optional): The head module to do prediction and calculate
|
||||
loss from processed features. See :mod:`mmpretrain.models.heads`.
|
||||
Notice that if the head is not set, almost all methods cannot be
|
||||
used except :meth:`extract_feat`. Defaults to None.
|
||||
target_generator (dict, optional): The target_generator module to
generate targets for self-supervised learning optimization, such as
HOG, extracted features from other modules (DALL-E, CLIP), etc.
|
||||
pretrained (str, optional): The pretrained checkpoint path, support
|
||||
local path and remote path. Defaults to None.
|
||||
data_preprocessor (Union[dict, nn.Module], optional): The config for
|
||||
preprocessing input data. If None or no specified type, it will use
|
||||
"SelfSupDataPreprocessor" as type.
|
||||
See :class:`SelfSupDataPreprocessor` for more details.
|
||||
Defaults to None.
|
||||
init_cfg (dict, optional): the config to control the initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
backbone: dict,
|
||||
neck: Optional[dict] = None,
|
||||
head: Optional[dict] = None,
|
||||
target_generator: Optional[dict] = None,
|
||||
pretrained: Optional[str] = None,
|
||||
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
|
||||
init_cfg: Optional[dict] = None):
|
||||
if pretrained is not None:
|
||||
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
|
||||
|
||||
data_preprocessor = data_preprocessor or {}
|
||||
if isinstance(data_preprocessor, dict):
|
||||
data_preprocessor.setdefault('type', 'SelfSupDataPreprocessor')
|
||||
data_preprocessor = MODELS.build(data_preprocessor)
|
||||
elif not isinstance(data_preprocessor, nn.Module):
|
||||
raise TypeError('data_preprocessor should be a `dict` or '
|
||||
f'`nn.Module` instance, but got '
|
||||
f'{type(data_preprocessor)}')
|
||||
|
||||
super().__init__(
|
||||
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
|
||||
|
||||
if not isinstance(backbone, nn.Module):
|
||||
backbone = MODELS.build(backbone)
|
||||
if neck is not None and not isinstance(neck, nn.Module):
|
||||
neck = MODELS.build(neck)
|
||||
if head is not None and not isinstance(head, nn.Module):
|
||||
head = MODELS.build(head)
|
||||
if target_generator is not None and not isinstance(
|
||||
target_generator, nn.Module):
|
||||
target_generator = MODELS.build(target_generator)
|
||||
|
||||
self.backbone = backbone
|
||||
self.neck = neck
|
||||
self.head = head
|
||||
self.target_generator = target_generator
|
||||
|
||||
@property
|
||||
def with_neck(self) -> bool:
|
||||
"""Check if the model has a neck module."""
|
||||
return hasattr(self, 'neck') and self.neck is not None
|
||||
|
||||
@property
|
||||
def with_head(self) -> bool:
|
||||
"""Check if the model has a head module."""
|
||||
return hasattr(self, 'head') and self.head is not None
|
||||
|
||||
@property
|
||||
def with_target_generator(self) -> bool:
|
||||
"""Check if the model has a target_generator module."""
|
||||
return hasattr(
|
||||
self, 'target_generator') and self.target_generator is not None
|
||||
|
||||
def forward(self,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[DataSample]] = None,
|
||||
mode: str = 'tensor'):
|
||||
"""The unified entry for a forward process in both training and test.
|
||||
|
||||
The method currently accepts two modes: "tensor" and "loss":
|
||||
|
||||
- "tensor": Forward the backbone network and return the feature
|
||||
tensor(s) tensor without any post-processing, same as a common
|
||||
PyTorch Module.
|
||||
- "loss": Forward and return a dict of losses according to the given
|
||||
inputs and data samples.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[DataSample], optional): The other data of
|
||||
every sample. It's required by some algorithms
|
||||
if ``mode="loss"``. Defaults to None.
|
||||
mode (str): Return what kind of value. Defaults to 'tensor'.
|
||||
|
||||
Returns:
|
||||
The return type depends on ``mode``.
|
||||
|
||||
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
|
||||
- If ``mode="loss"``, return a dict of tensor.
|
||||
"""
|
||||
if mode == 'tensor':
|
||||
feats = self.extract_feat(inputs)
|
||||
return feats
|
||||
elif mode == 'loss':
|
||||
return self.loss(inputs, data_samples)
|
||||
else:
|
||||
raise RuntimeError(f'Invalid mode "{mode}".')
|
||||
|
||||
@abstractmethod
|
||||
def extract_feat(self, inputs: torch.Tensor):
|
||||
"""Extract features from the input tensor with shape (N, C, ...).
|
||||
|
||||
Subclasses are recommended to implement this method to extract
features from the backbone and neck.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): A batch of inputs. The shape of it should be
|
||||
``(num_samples, num_channels, *img_shape)``.
|
||||
|
||||
Returns:
|
||||
tuple | Tensor: The output feature tensor(s).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def loss(self, inputs: torch.Tensor,
|
||||
data_samples: List[DataSample]) -> dict:
|
||||
"""Calculate losses from a batch of inputs and data samples.
|
||||
|
||||
This is an abstract method, and subclasses should override it
if needed.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input tensor with shape
|
||||
(N, C, ...) in general.
|
||||
data_samples (List[DataSample]): The annotation data of
|
||||
every sample.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
raise NotImplementedError
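A minimal hedged sketch of a concrete subclass (the class itself is hypothetical, not part of this PR): it only has to provide `extract_feat` and `loss`; building the backbone, neck, head and target generator is handled by the base class above.

from typing import List

import torch

from mmpretrain.models.selfsup import BaseSelfSupervisor
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


@MODELS.register_module()
class ToySelfSupervisor(BaseSelfSupervisor):  # hypothetical example class

    def extract_feat(self, inputs: torch.Tensor):
        # Backbone (and optional neck) features, as recommended above.
        feats = self.backbone(inputs)
        if self.with_neck:
            feats = self.neck(feats)
        return feats

    def loss(self, inputs: torch.Tensor,
             data_samples: List[DataSample]) -> dict:
        # Assumes the configured head exposes a `loss` method, as
        # mmpretrain heads conventionally do.
        feats = self.extract_feat(inputs)
        return self.head.loss(feats, data_samples)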
|
|
@@ -0,0 +1,280 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
from torch import nn
|
||||
|
||||
from mmpretrain.models import BEiTViT
|
||||
from mmpretrain.models.utils import NormEMAVectorQuantizer, resize_pos_embed
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class VQKD(BaseModule):
|
||||
"""Vector-Quantized Knowledge Distillation.
|
||||
|
||||
The module only contains the encoder and the VectorQuantizer part.
|
||||
Modified from https://github.com/microsoft/unilm/blob/master/beit2/modeling_vqkd.py
|
||||
|
||||
Args:
|
||||
encoder_config (dict): The config of encoder.
|
||||
decoder_config (dict, optional): The config of decoder. Currently,
|
||||
VQKD only supports building the encoder. Defaults to None.
|
||||
num_embed (int): Number of embedding vectors in the codebook. Defaults
|
||||
to 8192.
|
||||
embed_dims (int) : The dimension of embedding vectors in the codebook.
|
||||
Defaults to 32.
|
||||
decay (float): The decay parameter of EMA. Defaults to 0.99.
|
||||
beta (float): The multiplier for VectorQuantizer loss. Defaults to 1.
|
||||
quantize_kmeans_init (bool): Whether to use k-means to initialize the
|
||||
VectorQuantizer. Defaults to True.
|
||||
init_cfg (dict or List[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self,
|
||||
encoder_config: dict,
|
||||
decoder_config: Optional[dict] = None,
|
||||
num_embed: int = 8192,
|
||||
embed_dims: int = 32,
|
||||
decay: float = 0.99,
|
||||
beta: float = 1.0,
|
||||
quantize_kmeans_init: bool = True,
|
||||
init_cfg: Optional[dict] = None) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
|
||||
self.encoder = BEiTViT(**encoder_config)
|
||||
if decoder_config is not None:
|
||||
self.decoder = BEiTViT(**decoder_config)
|
||||
|
||||
self.quantize = NormEMAVectorQuantizer(
|
||||
num_embed=num_embed,
|
||||
embed_dims=embed_dims,
|
||||
beta=beta,
|
||||
decay=decay,
|
||||
kmeans_init=quantize_kmeans_init,
|
||||
)
|
||||
|
||||
# task layer
|
||||
self.encode_task_layer = nn.Sequential(
|
||||
nn.Linear(self.encoder.arch_settings['embed_dims'],
|
||||
self.encoder.arch_settings['embed_dims']), nn.Tanh(),
|
||||
nn.Linear(self.encoder.arch_settings['embed_dims'], embed_dims))
|
||||
|
||||
def get_tokens(self, x: torch.Tensor) -> dict:
|
||||
"""Get tokens for beit pre-training."""
|
||||
_, embed_ind, _ = self.encode(x)
|
||||
output = {}
|
||||
output['token'] = embed_ind.view(x.shape[0], -1)
|
||||
output['input_img'] = x
|
||||
|
||||
return output
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Encode the input images and get corresponding results."""
|
||||
encoder_features = self.encoder(x)[0]
|
||||
B, C, N1, N2 = encoder_features.shape
|
||||
encoder_features = encoder_features.permute(0, 2, 3,
|
||||
1).reshape(B, N1 * N2, C)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
to_quantizer_features = self.encode_task_layer(
|
||||
encoder_features.type_as(self.encode_task_layer[-1].weight))
|
||||
|
||||
N = to_quantizer_features.shape[1]
|
||||
h, w = int(math.sqrt(N)), int(math.sqrt(N))
|
||||
|
||||
to_quantizer_features = rearrange(
|
||||
to_quantizer_features, 'b (h w) c -> b c h w', h=h,
|
||||
w=w) # reshape for quantizer
|
||||
quantize, loss, embed_ind = self.quantize(to_quantizer_features)
|
||||
|
||||
return quantize, embed_ind, loss
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""The forward function.
|
||||
|
||||
Currently, it only supports getting tokens.
|
||||
"""
|
||||
return self.get_tokens(x)['token']
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BEiTPretrainViT(BEiTViT):
|
||||
"""Vision Transformer for BEiT pre-training.
|
||||
|
||||
Args:
|
||||
arch (str | dict): Vision Transformer architecture. If using a string,
choose from 'small', 'base' and 'large'. If using a dict, it should
have the below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **num_layers** (int): The number of transformer encoder layers.
|
||||
- **num_heads** (int): The number of heads in attention modules.
|
||||
- **feedforward_channels** (int): The hidden dimensions in
|
||||
feedforward modules.
|
||||
|
||||
Defaults to 'base'.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to the most
|
||||
common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 16.
|
||||
in_channels (int): The number of input channels. Defaults to 3.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, which means the last stage.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
qkv_bias (bool): Whether to add bias for qkv in attention modules.
|
||||
Defaults to True.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add an additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
with_cls_token (bool): Whether concatenating class token into image
|
||||
tokens as transformer input. Defaults to True.
|
||||
avg_token (bool): Whether or not to use the mean patch token for
|
||||
classification. If True, the model will only take the average
|
||||
of all patch tokens. Defaults to False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
``with_cls_token`` must be True. Defaults to True.
|
||||
use_abs_pos_emb (bool): Whether or not to use absolute position embedding.
Defaults to False.
use_rel_pos_bias (bool): Whether or not to use relative position bias.
Defaults to False.
use_shared_rel_pos_bias (bool): Whether or not to use shared relative
position bias. Defaults to True.
|
||||
layer_scale_init_value (float): The initialization value for
|
||||
the learnable scaling of attention and FFN. Defaults to 0.1.
|
||||
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
|
||||
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||
encoder. Defaults to an empty dict.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch: str = 'base',
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
in_channels: int = 3,
|
||||
out_indices: int = -1,
|
||||
drop_rate: float = 0,
|
||||
drop_path_rate: float = 0,
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
final_norm: bool = True,
|
||||
avg_token: bool = False,
|
||||
frozen_stages: int = -1,
|
||||
output_cls_token: bool = True,
|
||||
use_abs_pos_emb: bool = False,
|
||||
use_rel_pos_bias: bool = False,
|
||||
use_shared_rel_pos_bias: bool = True,
|
||||
layer_scale_init_value: int = 0.1,
|
||||
interpolate_mode: str = 'bicubic',
|
||||
patch_cfg: dict = dict(padding=0),
|
||||
layer_cfgs: dict = dict(),
|
||||
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
|
||||
super().__init__(
|
||||
arch=arch,
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
out_indices=out_indices,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_cfg=norm_cfg,
|
||||
final_norm=final_norm,
|
||||
avg_token=avg_token,
|
||||
frozen_stages=frozen_stages,
|
||||
output_cls_token=output_cls_token,
|
||||
use_abs_pos_emb=use_abs_pos_emb,
|
||||
use_shared_rel_pos_bias=use_shared_rel_pos_bias,
|
||||
use_rel_pos_bias=use_rel_pos_bias,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
interpolate_mode=interpolate_mode,
|
||||
patch_cfg=patch_cfg,
|
||||
layer_cfgs=layer_cfgs,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
|
||||
def init_weights(self) -> None:
|
||||
"""Initialize position embedding, patch embedding and cls token."""
|
||||
super().init_weights()
|
||||
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
# Suppress default init if use pretrained model.
|
||||
return
|
||||
|
||||
trunc_normal_(self.cls_token, std=0.02)
|
||||
trunc_normal_(self.mask_token, std=0.02)
|
||||
self.rescale_init_weight()
|
||||
|
||||
def rescale_init_weight(self) -> None:
|
||||
"""Rescale the initialized weights."""
|
||||
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.layers):
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
|
||||
|
||||
def forward(self, x: torch.Tensor,
|
||||
mask: torch.Tensor) -> Tuple[torch.Tensor]:
|
||||
"""The BEiT style forward function.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images, which is of shape (B x C x H x W).
|
||||
mask (torch.Tensor): Mask for input, which is of shape
|
||||
(B x patch_resolution[0] x patch_resolution[1]).
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor]: Hidden features.
|
||||
"""
|
||||
x, patch_resolution = self.patch_embed(x)
|
||||
|
||||
# replace the masked visual tokens by mask_token
|
||||
B, L, _ = x.shape
|
||||
mask_token = self.mask_token.expand(B, L, -1)
|
||||
w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
|
||||
x = x * (1. - w) + mask_token * w
|
||||
|
||||
# stole cls_tokens impl from Phil Wang, thanks
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
if self.pos_embed is not None:
|
||||
x = x + resize_pos_embed(
|
||||
self.pos_embed,
|
||||
self.patch_resolution,
|
||||
patch_resolution,
|
||||
mode=self.interpolate_mode,
|
||||
num_extra_tokens=self.num_extra_tokens)
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
self.shared_rel_pos_bias = self.rel_pos_bias().to(
|
||||
mask.device) if self.rel_pos_bias is not None else None
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x, rel_pos_bias=self.shared_rel_pos_bias)
|
||||
|
||||
if i == len(self.layers) - 1 and self.final_norm:
|
||||
x = self.norm1(x)
|
||||
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
|
||||
return tuple(outs)
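A shape sketch for the masked forward above (the tensors are illustrative assumptions): masked patch embeddings are replaced by mask_token, and the hidden states of all 1 + 196 tokens are returned.

import torch

from mmpretrain.models.selfsup import BEiTPretrainViT

model = BEiTPretrainViT(arch='base', img_size=224, patch_size=16)
imgs = torch.rand(2, 3, 224, 224)
mask = torch.zeros(2, 14, 14, dtype=torch.bool)  # patch_resolution is 14 x 14
mask[:, :7, :] = True                            # mask the top half of the patches
feats = model(imgs, mask)[0]                     # (2, 1 + 196, 768) for the 'base' arch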
|
|
@@ -0,0 +1,348 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Part of code is modified from BEiT
|
||||
# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import attr
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModule, ModuleList
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmpretrain.models import VisionTransformer
|
||||
from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import build_2d_sincos_position_embedding
|
||||
|
||||
|
||||
@attr.s(eq=False)
|
||||
class Conv2d(nn.Module):
|
||||
n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
|
||||
n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
|
||||
kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)
|
||||
|
||||
use_float16: bool = attr.ib(default=True)
|
||||
device: torch.device = attr.ib(default=torch.device('cpu'))
|
||||
requires_grad: bool = attr.ib(default=False)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
w = torch.empty((self.n_out, self.n_in, self.kw, self.kw),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
requires_grad=self.requires_grad)
|
||||
w.normal_(std=1 / math.sqrt(self.n_in * self.kw**2))
|
||||
|
||||
b = torch.zeros((self.n_out, ),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
requires_grad=self.requires_grad)
|
||||
self.w, self.b = nn.Parameter(w), nn.Parameter(b)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.use_float16 and 'cuda' in self.w.device.type:
|
||||
if x.dtype != torch.float16:
|
||||
x = x.half()
|
||||
|
||||
w, b = self.w.half(), self.b.half()
|
||||
else:
|
||||
if x.dtype != torch.float32:
|
||||
x = x.float()
|
||||
|
||||
w, b = self.w, self.b
|
||||
|
||||
return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
|
||||
|
||||
|
||||
@attr.s(eq=False, repr=False)
|
||||
class EncoderBlock(nn.Module):
|
||||
n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
|
||||
n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
|
||||
n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)
|
||||
|
||||
device: torch.device = attr.ib(default=None)
|
||||
requires_grad: bool = attr.ib(default=False)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
super().__init__()
|
||||
self.n_hid = self.n_out // 4
|
||||
self.post_gain = 1 / (self.n_layers**2)
|
||||
|
||||
make_conv = partial(
|
||||
Conv2d, device=self.device, requires_grad=self.requires_grad)
|
||||
self.id_path = make_conv(
|
||||
self.n_in, self.n_out,
|
||||
1) if self.n_in != self.n_out else nn.Identity()
|
||||
self.res_path = nn.Sequential(
|
||||
OrderedDict([
|
||||
('relu_1', nn.ReLU()),
|
||||
('conv_1', make_conv(self.n_in, self.n_hid, 3)),
|
||||
('relu_2', nn.ReLU()),
|
||||
('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
|
||||
('relu_3', nn.ReLU()),
|
||||
('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
|
||||
('relu_4', nn.ReLU()),
|
||||
('conv_4', make_conv(self.n_hid, self.n_out, 1)),
|
||||
]))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.id_path(x) + self.post_gain * self.res_path(x)
|
||||
|
||||
|
||||
@attr.s(eq=False, repr=False)
|
||||
@MODELS.register_module(name='DALL-E')
|
||||
class Encoder(BaseModule):
|
||||
group_count: int = 4
|
||||
n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
|
||||
n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
|
||||
input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
|
||||
vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)
|
||||
|
||||
device: torch.device = attr.ib(default=torch.device('cpu'))
|
||||
requires_grad: bool = attr.ib(default=False)
|
||||
use_mixed_precision: bool = attr.ib(default=True)
|
||||
init_cfg: Optional[Union[dict, List[dict]]] = attr.ib(default=None)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
super().__init__(init_cfg=self.init_cfg)
|
||||
|
||||
blk_range = range(self.n_blk_per_group)
|
||||
n_layers = self.group_count * self.n_blk_per_group
|
||||
make_conv = partial(
|
||||
Conv2d, device=self.device, requires_grad=self.requires_grad)
|
||||
make_blk = partial(
|
||||
EncoderBlock,
|
||||
n_layers=n_layers,
|
||||
device=self.device,
|
||||
requires_grad=self.requires_grad)
|
||||
|
||||
self.blocks = nn.Sequential(
|
||||
OrderedDict([
|
||||
('input', make_conv(self.input_channels, 1 * self.n_hid, 7)),
|
||||
('group_1',
|
||||
nn.Sequential(
|
||||
OrderedDict([
|
||||
*[(f'block_{i + 1}',
|
||||
make_blk(1 * self.n_hid, 1 * self.n_hid))
|
||||
for i in blk_range],
|
||||
('pool', nn.MaxPool2d(kernel_size=2)),
|
||||
]))),
|
||||
('group_2',
|
||||
nn.Sequential(
|
||||
OrderedDict([
|
||||
*[(f'block_{i + 1}',
|
||||
make_blk(
|
||||
1 * self.n_hid if i == 0 else 2 * self.n_hid,
|
||||
2 * self.n_hid)) for i in blk_range],
|
||||
('pool', nn.MaxPool2d(kernel_size=2)),
|
||||
]))),
|
||||
('group_3',
|
||||
nn.Sequential(
|
||||
OrderedDict([
|
||||
*[(f'block_{i + 1}',
|
||||
make_blk(
|
||||
2 * self.n_hid if i == 0 else 4 * self.n_hid,
|
||||
4 * self.n_hid)) for i in blk_range],
|
||||
('pool', nn.MaxPool2d(kernel_size=2)),
|
||||
]))),
|
||||
('group_4',
|
||||
nn.Sequential(
|
||||
OrderedDict([
|
||||
*[(f'block_{i + 1}',
|
||||
make_blk(
|
||||
4 * self.n_hid if i == 0 else 8 * self.n_hid,
|
||||
8 * self.n_hid)) for i in blk_range],
|
||||
]))),
|
||||
('output',
|
||||
nn.Sequential(
|
||||
OrderedDict([
|
||||
('relu', nn.ReLU()),
|
||||
('conv',
|
||||
make_conv(
|
||||
8 * self.n_hid,
|
||||
self.vocab_size,
|
||||
1,
|
||||
use_float16=False)),
|
||||
]))),
|
||||
]))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.float()
|
||||
if len(x.shape) != 4:
|
||||
raise ValueError(f'input shape {x.shape} is not 4d')
|
||||
if x.shape[1] != self.input_channels:
|
||||
raise ValueError(f'input has {x.shape[1]} channels but model \
|
||||
built for {self.input_channels}')
|
||||
if x.dtype != torch.float32:
|
||||
raise ValueError('input must have dtype torch.float32')
|
||||
|
||||
return self.blocks(x)
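An illustrative shape check for the DALL-E encoder above (the input size is an assumption): the three max-pools downsample by 8, so a 112x112 target crop yields 14x14 token logits over the 8192-entry vocabulary.

import torch

from mmpretrain.models.selfsup import Encoder

dalle_encoder = Encoder()                 # registered under the name 'DALL-E'
target_imgs = torch.rand(2, 3, 112, 112)  # must be float32 with 3 channels
logits = dalle_encoder(target_imgs)       # (2, 8192, 14, 14)
tokens = logits.argmax(dim=1)             # discrete targets, (2, 14, 14)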
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CAEViT(VisionTransformer):
|
||||
"""Vision Transformer for CAE pre-training.
|
||||
|
||||
Rewritten version of: `An Image is Worth 16x16 Words: Transformers
|
||||
for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
|
||||
|
||||
Args:
|
||||
arch (str | dict): Vision Transformer architecture. Default: 'b'
|
||||
img_size (int | tuple): Input image size
|
||||
patch_size (int | tuple): The patch size
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, which means the last stage.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
bias (bool | str): The option to add learnable bias for q, k, v. If bias
is True, it will add learnable bias. If bias is 'qv_bias', it will
only add learnable bias for q, v. If bias is False, it will not add
bias for q, k, v. Defaults to 'qv_bias'.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add an additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
`with_cls_token` must be True. Defaults to True.
|
||||
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
layer_scale_init_value (float, optional): The init value of gamma in
BEiTTransformerEncoderLayer. Defaults to None.
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
|
||||
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||
encoder. Defaults to an empty dict.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch: str = 'b',
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
out_indices: int = -1,
|
||||
drop_rate: float = 0,
|
||||
drop_path_rate: float = 0,
|
||||
bias: bool = 'qv_bias',
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
final_norm: bool = True,
|
||||
output_cls_token: bool = True,
|
||||
interpolate_mode: str = 'bicubic',
|
||||
layer_scale_init_value: float = None,
|
||||
patch_cfg: dict = dict(),
|
||||
layer_cfgs: dict = dict(),
|
||||
init_cfg: dict = None) -> None:
|
||||
super().__init__(
|
||||
arch=arch,
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
out_indices=out_indices,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_cfg=norm_cfg,
|
||||
final_norm=final_norm,
|
||||
output_cls_token=output_cls_token,
|
||||
interpolate_mode=interpolate_mode,
|
||||
patch_cfg=patch_cfg,
|
||||
layer_cfgs=layer_cfgs,
|
||||
init_cfg=init_cfg)
|
||||
self.pos_embed.requires_grad = False
|
||||
self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
dpr = np.linspace(0, drop_path_rate, self.num_layers)
|
||||
|
||||
# Replace original TransformerEncoderLayer with
|
||||
# BEiTTransformerEncoderLayer
|
||||
self.layers = ModuleList()
|
||||
if isinstance(layer_cfgs, dict):
|
||||
layer_cfgs = [layer_cfgs] * self.num_layers
|
||||
for i in range(self.num_layers):
|
||||
_layer_cfg = dict(
|
||||
embed_dims=self.embed_dims,
|
||||
num_heads=self.arch_settings['num_heads'],
|
||||
feedforward_channels=self.
|
||||
arch_settings['feedforward_channels'],
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
window_size=None,
|
||||
# setting `use_rel_pos_bias` to False ignores the `window_size`
|
||||
use_rel_pos_bias=False,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=dpr[i],
|
||||
bias=bias,
|
||||
norm_cfg=norm_cfg)
|
||||
_layer_cfg.update(layer_cfgs[i])
|
||||
self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
|
||||
|
||||
def init_weights(self) -> None:
|
||||
"""Initialize position embedding, patch embedding and cls token."""
|
||||
super().init_weights()
|
||||
if not (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
# initialize position embedding in backbone
|
||||
pos_embed = build_2d_sincos_position_embedding(
|
||||
int(self.num_patches**.5),
|
||||
self.pos_embed.shape[-1],
|
||||
cls_token=True)
|
||||
self.pos_embed.data.copy_(pos_embed.float())
|
||||
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m) -> None:
|
||||
"""Initialize the weights."""
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
"""Generate features for masked images.
|
||||
|
||||
This function applies the given mask to the input image and gets the
hidden features of the visible patches.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Input images, which are of shape B x C x H x W.
|
||||
mask (torch.Tensor): Mask for input, which is of shape B x L.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: hidden features.
|
||||
"""
|
||||
x, _ = self.patch_embed(img)
|
||||
batch_size, _, dim = x.size()
|
||||
|
||||
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
||||
|
||||
# NOTE: unmasked embeddings
|
||||
x_unmasked = x[~mask].reshape(batch_size, -1, dim)
|
||||
x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1)
|
||||
|
||||
pos_embed = self.pos_embed.expand(batch_size, self.num_patches + 1,
|
||||
dim)
|
||||
pos_embed_unmasked = pos_embed[:,
|
||||
1:][~mask].reshape(batch_size, -1, dim)
|
||||
pos_embed_unmasked = torch.cat((pos_embed[:, :1], pos_embed_unmasked),
|
||||
dim=1)
|
||||
x_unmasked = x_unmasked + pos_embed_unmasked
|
||||
|
||||
x_unmasked = self.drop_after_pos(x_unmasked)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
x_unmasked = layer(x=x_unmasked, rel_pos_bias=None)
|
||||
|
||||
if i == len(self.layers) - 1 and self.final_norm:
|
||||
x_unmasked = self.norm1(x_unmasked)
|
||||
|
||||
return x_unmasked
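A hedged usage sketch of the masked forward above (the mask layout is an assumption): with a boolean mask over the 196 patches, only the visible patches plus the cls token pass through the encoder layers.

import torch

from mmpretrain.models.selfsup import CAEViT

model = CAEViT(arch='b', patch_size=16, layer_scale_init_value=0.1)
imgs = torch.rand(2, 3, 224, 224)
mask = torch.zeros(2, 196, dtype=torch.bool)   # 196 = (224 / 16) ** 2 patches
mask[:, 98:] = True                            # mask the second half of the patches
visible = model(imgs, mask)                    # (2, 1 + 98, 768): cls + visible patches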
|
|
@@ -0,0 +1,178 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from mmpretrain.models import VisionTransformer
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import build_2d_sincos_position_embedding
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MAEViT(VisionTransformer):
|
||||
"""Vision Transformer for MAE pre-training.
|
||||
|
||||
A PyTorch implement of: `An Image is Worth 16x16 Words: Transformers
|
||||
for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
|
||||
This module implements the patch masking in MAE and initializes the
|
||||
position embedding with sine-cosine position embedding.
|
||||
|
||||
Args:
|
||||
arch (str | dict): Vision Transformer architecture
|
||||
Default: 'b'
|
||||
img_size (int | tuple): Input image size
|
||||
patch_size (int | tuple): The patch size
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, which means the last stage.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add an additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
`with_cls_token` must be True. Defaults to True.
|
||||
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
|
||||
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||
encoder. Defaults to an empty dict.
|
||||
mask_ratio (float): The ratio of total number of patches to be masked.
|
||||
Defaults to 0.75.
|
||||
init_cfg (Union[List[dict], dict], optional): Initialization config
|
||||
dict. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch: Union[str, dict] = 'b',
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
out_indices: Union[Sequence, int] = -1,
|
||||
drop_rate: float = 0,
|
||||
drop_path_rate: float = 0,
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
final_norm: bool = True,
|
||||
output_cls_token: bool = True,
|
||||
interpolate_mode: str = 'bicubic',
|
||||
patch_cfg: dict = dict(),
|
||||
layer_cfgs: dict = dict(),
|
||||
mask_ratio: float = 0.75,
|
||||
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
|
||||
super().__init__(
|
||||
arch=arch,
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
out_indices=out_indices,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_cfg=norm_cfg,
|
||||
final_norm=final_norm,
|
||||
output_cls_token=output_cls_token,
|
||||
interpolate_mode=interpolate_mode,
|
||||
patch_cfg=patch_cfg,
|
||||
layer_cfgs=layer_cfgs,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
# position embedding is not learnable during pretraining
|
||||
self.pos_embed.requires_grad = False
|
||||
self.mask_ratio = mask_ratio
|
||||
self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
def init_weights(self) -> None:
|
||||
"""Initialize position embedding, patch embedding and cls token."""
|
||||
super().init_weights()
|
||||
pos_embed = build_2d_sincos_position_embedding(
|
||||
int(self.num_patches**.5),
|
||||
self.pos_embed.shape[-1],
|
||||
cls_token=True)
|
||||
self.pos_embed.data.copy_(pos_embed.float())
|
||||
|
||||
w = self.patch_embed.projection.weight.data
|
||||
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
|
||||
torch.nn.init.normal_(self.cls_token, std=.02)
|
||||
|
||||
def random_masking(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask_ratio: float = 0.75
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Generate the mask for MAE Pre-training.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Image with data augmentation applied, which is
|
||||
of shape B x L x C.
|
||||
mask_ratio (float): The mask ratio of total patches.
|
||||
Defaults to 0.75.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
masked image, mask and the ids to restore original image.
|
||||
- x_masked (torch.Tensor): masked image.
|
||||
- mask (torch.Tensor): mask used to mask image.
|
||||
- ids_restore (torch.Tensor): ids to restore original image.
|
||||
"""
|
||||
N, L, D = x.shape # batch, length, dim
|
||||
len_keep = int(L * (1 - mask_ratio))
|
||||
|
||||
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
|
||||
|
||||
# sort noise for each sample
|
||||
ids_shuffle = torch.argsort(
|
||||
noise, dim=1) # ascend: small is keep, large is remove
|
||||
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
||||
|
||||
# keep the first subset
|
||||
ids_keep = ids_shuffle[:, :len_keep]
|
||||
x_masked = torch.gather(
|
||||
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
||||
|
||||
# generate the binary mask: 0 is keep, 1 is remove
|
||||
mask = torch.ones([N, L], device=x.device)
|
||||
mask[:, :len_keep] = 0
|
||||
# unshuffle to get the binary mask
|
||||
mask = torch.gather(mask, dim=1, index=ids_restore)
|
||||
|
||||
return x_masked, mask, ids_restore
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Generate features for masked images.
|
||||
|
||||
This function generates a random mask, masks some patches, and gets
the hidden features of the visible patches.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images, which is of shape B x C x H x W.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
|
||||
Hidden features, mask and the ids to restore original image.
|
||||
|
||||
- x (torch.Tensor): hidden features, which is of shape
|
||||
B x (L * (1 - mask_ratio) + 1) x C.
|
||||
- mask (torch.Tensor): mask used to mask image.
|
||||
- ids_restore (torch.Tensor): ids to restore original image.
|
||||
"""
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)[0]
|
||||
# add pos embed w/o cls token
|
||||
x = x + self.pos_embed[:, 1:, :]
|
||||
|
||||
# masking: length -> length * mask_ratio
|
||||
x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
|
||||
|
||||
# append cls token
|
||||
cls_token = self.cls_token + self.pos_embed[:, :1, :]
|
||||
cls_tokens = cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
for _, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
# Use final norm
|
||||
x = self.norm1(x)
|
||||
|
||||
return (x, mask, ids_restore)
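A shape sketch of the forward pass above (the input tensor is an illustrative assumption; the arguments follow the defaults in this file): with mask_ratio=0.75 only 49 of the 196 patches stay visible, so the encoder runs on 1 + 49 tokens per image.

import torch

from mmpretrain.models.selfsup import MAEViT

model = MAEViT(arch='b', patch_size=16, mask_ratio=0.75)
imgs = torch.rand(2, 3, 224, 224)
feats, mask, ids_restore = model(imgs)
# feats: (2, 1 + 49, 768), mask: (2, 196) with 1 = removed, ids_restore: (2, 196)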
|
|
@@ -0,0 +1,275 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmpretrain.models import VisionTransformer
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class HOGGenerator(BaseModule):
|
||||
"""Generate HOG feature for images.
|
||||
|
||||
This module is used in MaskFeat to generate HOG feature. The code is
|
||||
modified from file `slowfast/models/operators.py
|
||||
<https://github.com/facebookresearch/SlowFast/blob/main/slowfast/models/operators.py>`_.
|
||||
Here is the link of `HOG wikipedia
|
||||
<https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients>`_.
|
||||
|
||||
Args:
|
||||
nbins (int): Number of bins. Defaults to 9.
pool (int): Size of each pooling cell. Defaults to 8.
gaussian_window (int): Size of the Gaussian kernel. Defaults to 16.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
nbins: int = 9,
|
||||
pool: int = 8,
|
||||
gaussian_window: int = 16) -> None:
|
||||
super().__init__()
|
||||
self.nbins = nbins
|
||||
self.pool = pool
|
||||
self.pi = math.pi
|
||||
weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
|
||||
weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous()
|
||||
weight_y = weight_x.transpose(2, 3).contiguous()
|
||||
self.register_buffer('weight_x', weight_x)
|
||||
self.register_buffer('weight_y', weight_y)
|
||||
|
||||
self.gaussian_window = gaussian_window
|
||||
if gaussian_window:
|
||||
gaussian_kernel = self.get_gaussian_kernel(gaussian_window,
|
||||
gaussian_window // 2)
|
||||
self.register_buffer('gaussian_kernel', gaussian_kernel)
|
||||
|
||||
def get_gaussian_kernel(self, kernlen: int, std: int) -> torch.Tensor:
|
||||
"""Returns a 2D Gaussian kernel array."""
|
||||
|
||||
def _gaussian_fn(kernlen: int, std: int) -> torch.Tensor:
|
||||
n = torch.arange(0, kernlen).float()
|
||||
n -= n.mean()
|
||||
n /= std
|
||||
w = torch.exp(-0.5 * n**2)
|
||||
return w
|
||||
|
||||
kernel_1d = _gaussian_fn(kernlen, std)
|
||||
kernel_2d = kernel_1d[:, None] * kernel_1d[None, :]
|
||||
return kernel_2d / kernel_2d.sum()
|
||||
|
||||
def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor:
|
||||
"""Reshape HOG Features for output."""
|
||||
hog_feat = hog_feat.flatten(1, 2)
|
||||
self.unfold_size = hog_feat.shape[-1] // 14
|
||||
hog_feat = hog_feat.permute(0, 2, 3, 1)
|
||||
hog_feat = hog_feat.unfold(1, self.unfold_size,
|
||||
self.unfold_size).unfold(
|
||||
2, self.unfold_size, self.unfold_size)
|
||||
hog_feat = hog_feat.flatten(1, 2).flatten(2)
|
||||
return hog_feat
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Generate hog feature for each batch images.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images of shape (N, 3, H, W).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Hog features.
|
||||
"""
|
||||
# input is RGB image with shape [B 3 H W]
|
||||
self.h, self.w = x.size(-2), x.size(-1)
|
||||
x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect')
|
||||
gx_rgb = F.conv2d(
|
||||
x, self.weight_x, bias=None, stride=1, padding=0, groups=3)
|
||||
gy_rgb = F.conv2d(
|
||||
x, self.weight_y, bias=None, stride=1, padding=0, groups=3)
|
||||
norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1)
|
||||
phase = torch.atan2(gx_rgb, gy_rgb)
|
||||
phase = phase / self.pi * self.nbins # [-9, 9]
|
||||
|
||||
b, c, h, w = norm_rgb.shape
|
||||
out = torch.zeros((b, c, self.nbins, h, w),
|
||||
dtype=torch.float,
|
||||
device=x.device)
|
||||
phase = phase.view(b, c, 1, h, w)
|
||||
norm_rgb = norm_rgb.view(b, c, 1, h, w)
|
||||
if self.gaussian_window:
|
||||
if h != self.gaussian_window:
|
||||
assert h % self.gaussian_window == 0, 'h {} gw {}'.format(
|
||||
h, self.gaussian_window)
|
||||
repeat_rate = h // self.gaussian_window
|
||||
temp_gaussian_kernel = self.gaussian_kernel.repeat(
|
||||
[repeat_rate, repeat_rate])
|
||||
else:
|
||||
temp_gaussian_kernel = self.gaussian_kernel
|
||||
norm_rgb *= temp_gaussian_kernel
|
||||
|
||||
out.scatter_add_(2, phase.floor().long() % self.nbins, norm_rgb)
|
||||
|
||||
out = out.unfold(3, self.pool, self.pool)
|
||||
out = out.unfold(4, self.pool, self.pool)
|
||||
out = out.sum(dim=[-1, -2])
|
||||
|
||||
self.out = F.normalize(out, p=2, dim=2)
|
||||
|
||||
return self._reshape(self.out)
|
||||
|
||||
def generate_hog_image(self, hog_out: torch.Tensor) -> np.ndarray:
|
||||
"""Generate HOG image according to HOG features."""
|
||||
assert hog_out.size(0) == 1 and hog_out.size(1) == 3, \
|
||||
'Check the input batch size and the channel number, only support '\
|
||||
'"batch_size = 1".'
|
||||
hog_image = np.zeros([self.h, self.w])
|
||||
cell_gradient = np.array(hog_out.mean(dim=1).squeeze().detach().cpu())
|
||||
cell_width = self.pool / 2
|
||||
max_mag = np.array(cell_gradient).max()
|
||||
angle_gap = 360 / self.nbins
|
||||
|
||||
for x in range(cell_gradient.shape[1]):
|
||||
for y in range(cell_gradient.shape[2]):
|
||||
cell_grad = cell_gradient[:, x, y]
|
||||
cell_grad /= max_mag
|
||||
angle = 0
|
||||
for magnitude in cell_grad:
|
||||
angle_radian = math.radians(angle)
|
||||
x1 = int(x * self.pool +
|
||||
magnitude * cell_width * math.cos(angle_radian))
|
||||
y1 = int(y * self.pool +
|
||||
magnitude * cell_width * math.sin(angle_radian))
|
||||
x2 = int(x * self.pool -
|
||||
magnitude * cell_width * math.cos(angle_radian))
|
||||
y2 = int(y * self.pool -
|
||||
magnitude * cell_width * math.sin(angle_radian))
|
||||
magnitude = 0 if magnitude < 0 else magnitude
|
||||
cv2.line(hog_image, (y1, x1), (y2, x2),
|
||||
int(255 * math.sqrt(magnitude)))
|
||||
angle += angle_gap
|
||||
return hog_image
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MaskFeatViT(VisionTransformer):
|
||||
"""Vision Transformer for MaskFeat pre-training.
|
||||
|
||||
A PyTorch implement of: `Masked Feature Prediction for Self-Supervised
|
||||
Visual Pre-Training <https://arxiv.org/abs/2112.09133>`_.
|
||||
|
||||
Args:
|
||||
arch (str | dict): Vision Transformer architecture
|
||||
Default: 'b'
|
||||
img_size (int | tuple): Input image size
|
||||
patch_size (int | tuple): The patch size
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, which means the last stage.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add an additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token. If set True,
|
||||
`with_cls_token` must be True. Defaults to True.
|
||||
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
|
||||
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||
encoder. Defaults to an empty dict.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch: Union[str, dict] = 'b',
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
out_indices: Union[Sequence, int] = -1,
|
||||
drop_rate: float = 0,
|
||||
drop_path_rate: float = 0,
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
final_norm: bool = True,
|
||||
output_cls_token: bool = True,
|
||||
interpolate_mode: str = 'bicubic',
|
||||
patch_cfg: dict = dict(),
|
||||
layer_cfgs: dict = dict(),
|
||||
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
|
||||
super().__init__(
|
||||
arch=arch,
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
out_indices=out_indices,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_cfg=norm_cfg,
|
||||
final_norm=final_norm,
|
||||
output_cls_token=output_cls_token,
|
||||
interpolate_mode=interpolate_mode,
|
||||
patch_cfg=patch_cfg,
|
||||
layer_cfgs=layer_cfgs,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
self.mask_token = nn.parameter.Parameter(
|
||||
torch.zeros(1, 1, self.embed_dims), requires_grad=True)
|
||||
self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
|
||||
|
||||
def init_weights(self) -> None:
|
||||
"""Initialize position embedding, mask token and cls token."""
|
||||
super().init_weights()
|
||||
if not (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
|
||||
nn.init.trunc_normal_(self.cls_token, std=.02)
|
||||
nn.init.trunc_normal_(self.mask_token, std=.02)
|
||||
nn.init.trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m: torch.nn.Module) -> None:
|
||||
if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
"""Generate features for masked images.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images.
|
||||
mask (torch.Tensor): Input masks.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Features with cls_tokens.
|
||||
"""
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)[0]
|
||||
|
||||
# masking: length -> length * mask_ratio
|
||||
B, L, _ = x.shape
|
||||
mask_tokens = self.mask_token.expand(B, L, -1)
|
||||
mask = mask.flatten(1).unsqueeze(-1)
|
||||
x = x * (1 - mask.int()) + mask_tokens * mask
|
||||
|
||||
# append cls token
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
|
||||
if i == len(self.layers) - 1 and self.final_norm:
|
||||
x = self.norm1(x)
|
||||
|
||||
return x
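Illustrative target and feature shapes for the two modules above (the tensors are assumptions): with the defaults, HOGGenerator pools 8x8 cells and folds them into 14x14 patch-aligned descriptors of length 3 * 9 * 4 = 108, matching the 14x14 patch grid of MaskFeatViT.

import torch

from mmpretrain.models.selfsup import HOGGenerator, MaskFeatViT

hog = HOGGenerator(nbins=9, pool=8)
imgs = torch.rand(2, 3, 224, 224)
hog_targets = hog(imgs)                        # (2, 196, 108)

model = MaskFeatViT(arch='b', patch_size=16)
mask = torch.zeros(2, 14, 14, dtype=torch.bool)
mask[:, 7:, :] = True                          # mask the bottom half of the patches
feats = model(imgs, mask)                      # (2, 1 + 196, 768), masked patches use mask_token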
|
|
@@ -0,0 +1,149 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.model import BaseModule
|
||||
from mmengine.runner.checkpoint import _load_checkpoint
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import build_clip_model
|
||||
from .mae import MAEViT
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class CLIPGenerator(BaseModule):
|
||||
"""Get the features and attention from the last layer of CLIP.
|
||||
|
||||
This module is used to generate target features in masked image modeling.
|
||||
|
||||
Args:
|
||||
tokenizer_path (str): The path of the checkpoint of CLIP.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer_path: str) -> None:
|
||||
super().__init__()
|
||||
self.tokenizer = build_clip_model(
|
||||
_load_checkpoint(tokenizer_path), False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Get the features and attention from the last layer of CLIP.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input image, which is of shape (N, 3, H, W).
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
The features and attention from the last layer of CLIP,
|
||||
which are of shape (N, L, C) and (N, L, L), respectively.
|
||||
"""
|
||||
# use the visual branch of CLIP to get the features
|
||||
clip_features = self.tokenizer.encode_image(x)
|
||||
return clip_features
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MILANViT(MAEViT):
|
||||
"""Vision Transformer for MILAN pre-training.
|
||||
|
||||
Implementation of the encoder for `MILAN: Masked Image Pretraining on
|
||||
Language Assisted Representation <https://arxiv.org/abs/2208.06049>`_.
|
||||
|
||||
This module inherits from MAEViT and only overrides the forward function
|
||||
and replaces random masking with attention masking.
|
||||
"""
|
||||
|
||||
def attention_masking(
|
||||
self, x: torch.Tensor, mask_ratio: float, importance: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Generate attention mask for MILAN.
|
||||
|
||||
This is the key difference from MAEViT, which uses random masking.
|
||||
Attention masking generates attention mask for MILAN, according to
|
||||
importance. The higher the importance, the more likely the patch is
|
||||
kept.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images, which is of shape B x L x C.
|
||||
mask_ratio (float): The ratio of patches to be masked.
|
||||
importance (torch.Tensor): Importance of each patch, which is of
|
||||
shape B x L.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, ...]:
|
||||
- ``x_masked``: masked image
|
||||
- ``ids_restore``: the ids to restore original image
|
||||
- ``ids_keep``: ids of the kept patches
|
||||
- ``ids_dump``: ids of the removed patches
|
||||
"""
|
||||
N, L, D = x.shape # batch, length, dim
|
||||
len_keep = int(L * (1 - mask_ratio))
|
||||
|
||||
noise = importance.to(x.device) # large is keep, small is remove
|
||||
|
||||
# sort noise for each sample
|
||||
ids_shuffle = torch.multinomial(noise, L, replacement=False)
|
||||
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
||||
|
||||
# keep the first subset
|
||||
ids_keep = ids_shuffle[:, :len_keep]
|
||||
ids_dump = ids_shuffle[:, len_keep:]
|
||||
x_masked = torch.gather(
|
||||
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
||||
|
||||
# generate the binary mask: 0 is keep, 1 is remove
|
||||
mask = torch.ones([N, L], device=x.device)
|
||||
mask[:, :len_keep] = 0
|
||||
# unshuffle to get the binary mask
|
||||
mask = torch.gather(mask, dim=1, index=ids_restore)
|
||||
|
||||
return x_masked, ids_restore, ids_keep, ids_dump
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, importance: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Generate features for masked images.
|
||||
|
||||
This function generates the mask, masks the corresponding patches, and gets
|
||||
the hidden features for visible patches. The mask is generated by
|
||||
importance. The higher the importance, the more likely the patch is
|
||||
kept. The importance is calculated by CLIP. The higher the CLIP score,
|
||||
the more likely the patch is kept. The CLIP score is calculated by
|
||||
cross attention between the class token and all other tokens from
|
||||
the last layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images, which is of shape B x C x H x W.
|
||||
importance (torch.Tensor): Importance of each patch, which is of
|
||||
shape B x L.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, ...]:
|
||||
masked image, the ids to restore original image, ids of the
|
||||
kept patches, ids of the removed patches.
|
||||
- x (torch.Tensor): hidden features, which is of shape
|
||||
B x (L * mask_ratio) x C.
|
||||
- ids_restore (torch.Tensor): ids to restore original image.
|
||||
- ids_keep (torch.Tensor): ids of the kept patches.
|
||||
- ids_dump (torch.Tensor): ids of the removed patches.
|
||||
"""
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)[0]
|
||||
# add pos embed w/o cls token
|
||||
x = x + self.pos_embed[:, 1:, :]
|
||||
|
||||
# masking: length -> length * mask_ratio
|
||||
x, ids_restore, ids_keep, ids_dump = self.attention_masking(
|
||||
x, self.mask_ratio, importance)
|
||||
|
||||
# append cls token
|
||||
cls_token = self.cls_token + self.pos_embed[:, :1, :]
|
||||
cls_tokens = cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
|
||||
for _, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
# Use final norm
|
||||
x = self.norm1(x)
|
||||
|
||||
return x, ids_restore, ids_keep, ids_dump
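A hedged sketch of how this encoder consumes a patch-importance map; a random tensor stands in for the CLS-to-patch attention produced by the CLIP generator above:

import torch

encoder = MILANViT(arch='b', patch_size=16, mask_ratio=0.75)
images = torch.randn(2, 3, 224, 224)
importance = torch.rand(2, 196)  # placeholder for the CLIP attention scores
# Keeps the cls token plus ~25% of the patches, sampled in favour of the
# patches with the highest importance.
latent, ids_restore, ids_keep, ids_dump = encoder(images, importance)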
|
|
@ -0,0 +1,200 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import random
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from mmpretrain.models.backbones import MixMIMTransformer
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import build_2d_sincos_position_embedding
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MixMIMPretrainTransformer(MixMIMTransformer):
|
||||
"""MixMIM backbone for MixMIM pre-training.
|
||||
|
||||
A PyTorch implementation of: `MixMIM: Mixed and Masked Image
|
||||
Modeling for Efficient Visual Representation Learning
|
||||
<https://arxiv.org/abs/2205.13137>`_
|
||||
|
||||
Args:
|
||||
arch (str | dict): MixMIM architecture. If use string,
|
||||
choose from 'base','large' and 'huge'.
|
||||
If use dict, it should have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **depths** (int): The number of transformer encoder layers.
|
||||
- **num_heads** (int): The number of heads in attention modules.
|
||||
|
||||
Defaults to 'base'.
|
||||
mlp_ratio (int): The mlp ratio in FFN. Defaults to 4.
|
||||
img_size (int | tuple): The expected input image shape. Because we
|
||||
support dynamic input shape, just set the argument to
|
||||
the most common input image shape. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size in patch embedding.
|
||||
Defaults to 16.
|
||||
in_channels (int): The num of input channels. Defaults to 3.
|
||||
window_size (list): The height and width of the window.
|
||||
qkv_bias (bool): Whether to add bias for qkv in attention modules.
|
||||
Defaults to True.
|
||||
patch_cfg (dict): Extra config dict for patch embedding.
|
||||
Defaults to an empty dict.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
||||
attn_drop_rate (float): Attention drop rate. Defaults to 0.
|
||||
use_checkpoint (bool): Whether to use checkpointing to
|
||||
reduce GPU memory cost. Defaults to False.
|
||||
range_mask_ratio (float): The range of mask ratio.
|
||||
Defaults to 0.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch: Union[str, dict] = 'base',
|
||||
mlp_ratio: float = 4,
|
||||
img_size: int = 224,
|
||||
patch_size: int = 4,
|
||||
in_channels: int = 3,
|
||||
window_size: List = [14, 14, 14, 7],
|
||||
qkv_bias: bool = True,
|
||||
patch_cfg: dict = dict(),
|
||||
norm_cfg: dict = dict(type='LN'),
|
||||
drop_rate: float = 0.0,
|
||||
drop_path_rate: float = 0.0,
|
||||
attn_drop_rate: float = 0.0,
|
||||
use_checkpoint: bool = False,
|
||||
range_mask_ratio: float = 0.0,
|
||||
init_cfg: Optional[dict] = None) -> None:
|
||||
|
||||
super().__init__(
|
||||
arch=arch,
|
||||
mlp_ratio=mlp_ratio,
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
window_size=window_size,
|
||||
qkv_bias=qkv_bias,
|
||||
patch_cfg=patch_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
use_checkpoint=use_checkpoint,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
self.range_mask_ratio = range_mask_ratio
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize position embedding, patch embedding."""
|
||||
super(MixMIMTransformer, self).init_weights()
|
||||
|
||||
pos_embed = build_2d_sincos_position_embedding(
|
||||
int(self.num_patches**.5),
|
||||
self.absolute_pos_embed.shape[-1],
|
||||
cls_token=False)
|
||||
self.absolute_pos_embed.data.copy_(pos_embed.float())
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
# we use xavier_uniform following official JAX ViT:
|
||||
torch.nn.init.xavier_uniform_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def random_masking(self, x: torch.Tensor, mask_ratio: float = 0.5):
|
||||
"""Generate the mask for MixMIM Pretraining.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Image with data augmentation applied, which is
|
||||
of shape B x C x H x W.
|
||||
mask_ratio (float): The mask ratio of total patches.
|
||||
Defaults to 0.5.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
- mask_s1 (torch.Tensor): mask with stride of
|
||||
self.encoder_stride // 8.
|
||||
- mask_s2 (torch.Tensor): mask with stride of
|
||||
self.encoder_stride // 4.
|
||||
- mask_s3 (torch.Tensor): mask with stride of
|
||||
self.encoder_stride // 2.
|
||||
- mask (torch.Tensor): mask with stride of
|
||||
self.encoder_stride.
|
||||
"""
|
||||
|
||||
B, C, H, W = x.shape
|
||||
out_H = H // self.encoder_stride
|
||||
out_W = W // self.encoder_stride
|
||||
s3_H, s3_W = out_H * 2, out_W * 2
|
||||
s2_H, s2_W = out_H * 4, out_W * 4
|
||||
s1_H, s1_W = out_H * 8, out_W * 8
|
||||
|
||||
seq_l = out_H * out_W
|
||||
# use a shared mask for a batch images
|
||||
mask = torch.zeros([1, 1, seq_l], device=x.device)
|
||||
|
||||
mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio)
|
||||
noise = torch.rand(1, 1, seq_l, device=x.device) # noise in [0, 1]
|
||||
# ascend: small is keep, large is removed
|
||||
mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)]
|
||||
mask.scatter_(2, mask_idx, 1)
|
||||
mask = mask.reshape(1, 1, out_H, out_W)
|
||||
mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest')
|
||||
mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest')
|
||||
mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest')
|
||||
|
||||
mask = mask.reshape(1, out_H * out_W, 1).contiguous()
|
||||
mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous()
|
||||
mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous()
|
||||
mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous()
|
||||
|
||||
return mask_s1, mask_s2, mask_s3, mask
|
||||
|
||||
def forward(self, x: torch.Tensor, mask_ratio=0.5):
|
||||
"""Generate features for masked images.
|
||||
|
||||
This function generates the mask, masks some patches randomly, and gets
|
||||
the hidden features for visible patches.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images, which is of shape B x C x H x W.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
- x (torch.Tensor): hidden features, which is of shape
|
||||
B x L x C.
|
||||
- mask_s4 (torch.Tensor): the mask tensor for the last layer.
|
||||
"""
|
||||
|
||||
mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking(x, mask_ratio)
|
||||
|
||||
x, _ = self.patch_embed(x)
|
||||
|
||||
x = x * (1. - mask_s1) + x.flip(0) * mask_s1
|
||||
x = x + self.absolute_pos_embed
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
for idx, layer in enumerate(self.layers):
|
||||
if idx == 0:
|
||||
x = layer(x, attn_mask=mask_s1)
|
||||
elif idx == 1:
|
||||
x = layer(x, attn_mask=mask_s2)
|
||||
elif idx == 2:
|
||||
x = layer(x, attn_mask=mask_s3)
|
||||
elif idx == 3:
|
||||
x = layer(x, attn_mask=mask_s4)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x, mask_s4
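A hedged usage sketch for the MixMIM pre-training backbone with the default 'base' arch and a 0.5 mask ratio:

import torch

backbone = MixMIMPretrainTransformer(arch='base')
images = torch.randn(2, 3, 224, 224)
# Tokens of each image are mixed with the reversed batch according to a shared
# random mask; mask_s4 is that mask at the last-stage resolution.
latent, mask_s4 = backbone(images, mask_ratio=0.5)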
|
|
@ -0,0 +1,134 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from functools import reduce
|
||||
from operator import mul
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn.bricks.transformer import PatchEmbed
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmpretrain.models.backbones import VisionTransformer
|
||||
from mmpretrain.models.utils import (build_2d_sincos_position_embedding,
|
||||
to_2tuple)
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MoCoV3ViT(VisionTransformer):
|
||||
"""Vision Transformer for MoCoV3 pre-training.
|
||||
|
||||
A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers for
|
||||
Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
|
||||
|
||||
Part of the code is modified from:
|
||||
`<https://github.com/facebookresearch/moco-v3/blob/main/vits.py>`_.
|
||||
|
||||
Args:
|
||||
stop_grad_conv1 (bool): whether to stop the gradient of
|
||||
convolution layer in `PatchEmbed`. Defaults to False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Defaults to False.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
stop_grad_conv1: bool = False,
|
||||
frozen_stages: int = -1,
|
||||
norm_eval: bool = False,
|
||||
init_cfg: Optional[Union[dict, List[dict]]] = None,
|
||||
**kwargs) -> None:
|
||||
|
||||
# add MoCoV3 ViT-small arch
|
||||
self.arch_zoo.update(
|
||||
dict.fromkeys(
|
||||
['mocov3-s', 'mocov3-small'], {
|
||||
'embed_dims': 384,
|
||||
'num_layers': 12,
|
||||
'num_heads': 12,
|
||||
'feedforward_channels': 1536,
|
||||
}))
|
||||
|
||||
super().__init__(init_cfg=init_cfg, **kwargs)
|
||||
self.patch_size = kwargs['patch_size']
|
||||
self.frozen_stages = frozen_stages
|
||||
self.norm_eval = norm_eval
|
||||
self.init_cfg = init_cfg
|
||||
|
||||
if isinstance(self.patch_embed, PatchEmbed):
|
||||
if stop_grad_conv1:
|
||||
self.patch_embed.projection.weight.requires_grad = False
|
||||
self.patch_embed.projection.bias.requires_grad = False
|
||||
|
||||
self._freeze_stages()
|
||||
|
||||
def init_weights(self) -> None:
|
||||
"""Initialize position embedding, patch embedding, qkv layers and cls
|
||||
token."""
|
||||
super().init_weights()
|
||||
|
||||
if not (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
|
||||
# Use fixed 2D sin-cos position embedding
|
||||
pos_emb = build_2d_sincos_position_embedding(
|
||||
patches_resolution=self.patch_resolution,
|
||||
embed_dims=self.embed_dims,
|
||||
cls_token=True)
|
||||
self.pos_embed.data.copy_(pos_emb)
|
||||
self.pos_embed.requires_grad = False
|
||||
|
||||
# xavier_uniform initialization for PatchEmbed
|
||||
if isinstance(self.patch_embed, PatchEmbed):
|
||||
val = math.sqrt(
|
||||
6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) +
|
||||
self.embed_dims))
|
||||
nn.init.uniform_(self.patch_embed.projection.weight, -val, val)
|
||||
nn.init.zeros_(self.patch_embed.projection.bias)
|
||||
|
||||
# initialization for linear layers
|
||||
for name, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
if 'qkv' in name:
|
||||
# treat the weights of Q, K, V separately
|
||||
val = math.sqrt(
|
||||
6. /
|
||||
float(m.weight.shape[0] // 3 + m.weight.shape[1]))
|
||||
nn.init.uniform_(m.weight, -val, val)
|
||||
else:
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
nn.init.zeros_(m.bias)
|
||||
nn.init.normal_(self.cls_token, std=1e-6)
|
||||
|
||||
def _freeze_stages(self) -> None:
|
||||
"""Freeze patch_embed layer, some parameters and stages."""
|
||||
if self.frozen_stages >= 0:
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.cls_token.requires_grad = False
|
||||
self.pos_embed.requires_grad = False
|
||||
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
m = self.layers[i - 1]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if i == (self.num_layers) and self.final_norm:
|
||||
for param in getattr(self, 'norm1').parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True) -> None:
|
||||
super().train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval has effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
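A hedged sketch of a fully frozen MoCoV3 ViT-small encoder with the patch-embedding gradient stopped, roughly following the MoCo v3 recipe:

import torch

vit = MoCoV3ViT(
    arch='mocov3-small',
    patch_size=16,
    stop_grad_conv1=True,
    frozen_stages=12,
    norm_eval=True)
vit.init_weights()
vit.train()  # frozen stages stay in eval mode and keep requires_grad=False
feats = vit(torch.randn(2, 3, 224, 224))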
|
|
@ -0,0 +1,154 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmpretrain.models import SwinTransformer
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class SimMIMSwinTransformer(SwinTransformer):
|
||||
"""Swin Transformer for SimMIM pre-training.
|
||||
|
||||
Args:
|
||||
arch (str | dict): Swin Transformer architecture
|
||||
Defaults to 'T'.
|
||||
img_size (int | tuple): The size of input image.
|
||||
Defaults to 224.
|
||||
in_channels (int): The num of input channels.
|
||||
Defaults to 3.
|
||||
drop_rate (float): Dropout rate after embedding.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate.
|
||||
Defaults to 0.1.
|
||||
out_indices (tuple): Layers to be outputted. Defaults to (3, ).
|
||||
use_abs_pos_embed (bool): If True, add absolute position embedding to
|
||||
the patch embedding. Defaults to False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint
|
||||
will save some memory while slowing down the training speed.
|
||||
Defaults to False.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Defaults to False.
|
||||
norm_cfg (dict): Config dict for normalization layer at end
|
||||
of backbone. Defaults to ``dict(type='LN')``.
|
||||
stage_cfgs (Sequence | dict): Extra config dict for each
|
||||
stage. Defaults to empty dict.
|
||||
patch_cfg (dict): Extra config dict for patch embedding.
|
||||
Defaults to empty dict.
|
||||
pad_small_map (bool): If True, pad the small feature map to the window
|
||||
size, which is commonly used in detection and segmentation. If False,
|
||||
avoid shifting window and shrink the window size to the size of
|
||||
feature map, which is commonly used in classification.
|
||||
Defaults to False.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch: Union[str, dict] = 'T',
|
||||
img_size: Union[Tuple[int, int], int] = 224,
|
||||
in_channels: int = 3,
|
||||
drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.1,
|
||||
out_indices: tuple = (3, ),
|
||||
use_abs_pos_embed: bool = False,
|
||||
with_cp: bool = False,
|
||||
frozen_stages: int = -1,
|
||||
norm_eval: bool = False,
|
||||
norm_cfg: dict = dict(type='LN'),
|
||||
stage_cfgs: Union[Sequence, dict] = dict(),
|
||||
patch_cfg: dict = dict(),
|
||||
pad_small_map: bool = False,
|
||||
init_cfg: Optional[dict] = None) -> None:
|
||||
super().__init__(
|
||||
arch=arch,
|
||||
img_size=img_size,
|
||||
in_channels=in_channels,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
out_indices=out_indices,
|
||||
use_abs_pos_embed=use_abs_pos_embed,
|
||||
with_cp=with_cp,
|
||||
frozen_stages=frozen_stages,
|
||||
norm_eval=norm_eval,
|
||||
norm_cfg=norm_cfg,
|
||||
stage_cfgs=stage_cfgs,
|
||||
patch_cfg=patch_cfg,
|
||||
pad_small_map=pad_small_map,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
|
||||
def init_weights(self) -> None:
|
||||
"""Initialize weights."""
|
||||
super().init_weights()
|
||||
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
# Suppress default init if use pretrained model.
|
||||
return
|
||||
|
||||
if self.use_abs_pos_embed:
|
||||
trunc_normal_(self.absolute_pos_embed, std=0.02)
|
||||
|
||||
trunc_normal_(self.mask_token, mean=0, std=.02)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
"""Initialize weights."""
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def forward(self, x: torch.Tensor,
|
||||
mask: torch.Tensor) -> Sequence[torch.Tensor]:
|
||||
"""Generate features for masked images.
|
||||
|
||||
This function generates masked images and gets the hidden features for
|
||||
them.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images.
|
||||
mask (torch.Tensor): Masks used to construct masked images.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing features from multi-stages.
|
||||
"""
|
||||
x, hw_shape = self.patch_embed(x)
|
||||
|
||||
assert mask is not None
|
||||
B, L, _ = x.shape
|
||||
|
||||
mask_token = self.mask_token.expand(B, L, -1)
|
||||
w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
|
||||
x = x * (1. - w) + mask_token * w
|
||||
|
||||
if self.use_abs_pos_embed:
|
||||
x = x + self.absolute_pos_embed
|
||||
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
outs = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x, hw_shape = stage(x, hw_shape)
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
out = norm_layer(x)
|
||||
out = out.view(-1, *hw_shape,
|
||||
stage.out_channels).permute(0, 3, 1,
|
||||
2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
|
@ -5,6 +5,7 @@ from .attention import (BEiTAttention, ChannelMultiheadAttention,
|
|||
ShiftWindowMSA, WindowMSA, WindowMSAV2)
|
||||
from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix
|
||||
from .channel_shuffle import channel_shuffle
|
||||
from .clip_generator_helper import build_clip_model
|
||||
from .data_preprocessor import ClsDataPreprocessor
|
||||
from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed,
|
||||
resize_relative_position_bias_table)
|
||||
|
@ -17,6 +18,7 @@ from .position_encoding import (ConditionalPositionEncoding,
|
|||
PositionEncodingFourier,
|
||||
build_2d_sincos_position_embedding)
|
||||
from .se_layer import SELayer
|
||||
from .vector_quantizer import NormEMAVectorQuantizer
|
||||
|
||||
__all__ = [
|
||||
'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
|
||||
|
@ -28,5 +30,6 @@ __all__ = [
|
|||
'LayerScale', 'WindowMSA', 'WindowMSAV2', 'ChannelMultiheadAttention',
|
||||
'PositionEncodingFourier', 'LeAttention', 'GRN', 'LayerNorm2d',
|
||||
'build_norm_layer', 'CrossMultiheadAttention',
|
||||
'build_2d_sincos_position_embedding', 'PromptMultiheadAttention'
|
||||
'build_2d_sincos_position_embedding', 'PromptMultiheadAttention',
|
||||
'NormEMAVectorQuantizer', 'build_clip_model'
|
||||
]
|
||||
|
|
|
@ -567,7 +567,7 @@ class BEiTAttention(BaseModule):
|
|||
Args:
|
||||
embed_dims (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (tuple[int]): The height and width of the window.
|
||||
window_size (tuple[int, int]): The height and width of the window.
|
||||
use_rel_pos_bias (bool): Whether to use unique relative position bias,
|
||||
if False, use shared relative position bias defined in backbone.
|
||||
bias (str): The option to add learnable bias for q, k, v. If bias is
|
||||
|
@ -606,6 +606,10 @@ class BEiTAttention(BaseModule):
|
|||
self._init_qv_bias()
|
||||
qkv_bias = False
|
||||
|
||||
if window_size is None:
|
||||
assert not use_rel_pos_bias
|
||||
else:
|
||||
assert isinstance(window_size, tuple)
|
||||
self.window_size = window_size
|
||||
self.use_rel_pos_bias = use_rel_pos_bias
|
||||
self._init_rel_pos_embedding()
|
||||
|
|
|
@ -0,0 +1,391 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Modified from https://github.com/zejiangh/MILAN
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine.logging import MMLogger
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function."""
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
"""A faster version of GELU."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function."""
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
"""Residual Attention Block (RAB).
|
||||
|
||||
This module implements the same function as the MultiheadAttention in
|
||||
MMClassification, but with a different interface, which is mainly used
|
||||
in CLIP.
|
||||
|
||||
Args:
|
||||
d_model (int): The feature dimension.
|
||||
n_head (int): The number of attention heads.
|
||||
attn_mask (torch.Tensor, optional): The attention mask.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_model: int,
|
||||
n_head: int,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
return_attention: bool = False) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = nn.Sequential(
|
||||
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
|
||||
('gelu', QuickGELU()),
|
||||
('c_proj', nn.Linear(d_model * 4, d_model))]))
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
self.return_attention = return_attention
|
||||
|
||||
def attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Attention function."""
|
||||
self.attn_mask = self.attn_mask.to(
|
||||
dtype=x.dtype,
|
||||
device=x.device) if self.attn_mask is not None else None
|
||||
if self.return_attention:
|
||||
return self.attn(
|
||||
x,
|
||||
x,
|
||||
x,
|
||||
need_weights=self.return_attention,
|
||||
attn_mask=self.attn_mask)
|
||||
else:
|
||||
return self.attn(
|
||||
x,
|
||||
x,
|
||||
x,
|
||||
need_weights=self.return_attention,
|
||||
attn_mask=self.attn_mask)[0]
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Forward function."""
|
||||
if self.return_attention:
|
||||
x_, attention = self.attention(self.ln_1(x))
|
||||
x = x + x_
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x, attention
|
||||
else:
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
"""Transformer.
|
||||
|
||||
Both visual and text branches use this transformer.
|
||||
|
||||
Args:
|
||||
width (int): The feature dimension.
|
||||
layers (int): The number of layers.
|
||||
heads (int): The number of attention heads.
|
||||
attn_mask (torch.Tensor, optional): The attention mask.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
attn_mask: Optional[torch.Tensor] = None) -> None:
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.ModuleList()
|
||||
for _ in range(layers - 1):
|
||||
self.resblocks.append(
|
||||
ResidualAttentionBlock(width, heads, attn_mask))
|
||||
self.resblocks.append(
|
||||
ResidualAttentionBlock(
|
||||
width, heads, attn_mask, return_attention=True))
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Forward function."""
|
||||
z = []
|
||||
for idx, blk in enumerate(self.resblocks):
|
||||
if idx < self.layers - 1:
|
||||
x = blk(x)
|
||||
z.append(x.permute(1, 0, 2))
|
||||
else:
|
||||
x, attention = blk(x)
|
||||
z.append(x.permute(1, 0, 2))
|
||||
return x, attention, z
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
"""Vision Transformer for CLIP.
|
||||
|
||||
Args:
|
||||
input_resolution (int): The image size.
|
||||
patch_size (int): The patch size.
|
||||
width (int): The feature dimension.
|
||||
layers (int): The number of layers.
|
||||
heads (int): The number of attention heads.
|
||||
output_dim (int): The output dimension.
|
||||
finetune (bool): Whether to finetune the model. Defaults to False.
|
||||
average_targets (int): The number of targets to average. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_resolution: int,
|
||||
patch_size: int,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
output_dim: int,
|
||||
finetune=False,
|
||||
average_targets: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=3,
|
||||
out_channels=width,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False)
|
||||
|
||||
scale = width**-0.5
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
||||
self.positional_embedding = nn.Parameter(scale * torch.randn(
|
||||
(input_resolution // patch_size)**2 + 1, width))
|
||||
self.ln_pre = LayerNorm(width)
|
||||
|
||||
self.transformer = Transformer(width, layers, heads)
|
||||
|
||||
self.finetune = finetune
|
||||
if finetune is False:
|
||||
self.ln_post = LayerNorm(width)
|
||||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
||||
|
||||
self.average_targets = average_targets
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward function."""
|
||||
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||
x = x.reshape(x.shape[0], x.shape[1],
|
||||
-1) # shape = [*, width, grid ** 2]
|
||||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||
x = torch.cat([
|
||||
self.class_embedding.to(x.dtype) + torch.zeros(
|
||||
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
|
||||
],
|
||||
dim=1) # shape = [*, grid ** 2 + 1, width]
|
||||
x = x + self.positional_embedding.to(x.dtype)
|
||||
x = self.ln_pre(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x, attention, z = self.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
x = self.ln_post(x)
|
||||
if self.proj is not None:
|
||||
x = x @ self.proj
|
||||
|
||||
return x, attention
|
||||
|
||||
|
||||
class CLIP(nn.Module):
|
||||
"""CLIP.
|
||||
|
||||
Args:
|
||||
embed_dim (int): The embedding dimension.
|
||||
image_resolution (int): The image size.
|
||||
vision_layers (int): The number of layers in the vision transformer.
|
||||
vision_width (int): The feature dimension in the vision transformer.
|
||||
vision_patch_size (int): The patch size in the vision transformer.
|
||||
context_length (int): The context length.
|
||||
vocab_size (int): The vocabulary size.
|
||||
transformer_width (int): The feature dimension in the text transformer.
|
||||
transformer_heads (int): The number of attention heads in the
|
||||
text transformer.
|
||||
transformer_layers (int): The number of layers in the text transformer.
|
||||
finetune (bool): Whether to finetune the model. Defaults to False.
|
||||
average_targets (int): The number of targets to average. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
image_resolution: int,
|
||||
vision_layers: Union[Tuple[int, int, int, int], int],
|
||||
vision_width: int,
|
||||
vision_patch_size: int,
|
||||
context_length: int,
|
||||
vocab_size: int,
|
||||
transformer_width: int,
|
||||
transformer_heads: int,
|
||||
transformer_layers: int,
|
||||
finetune: bool = False,
|
||||
average_targets: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.context_length = context_length
|
||||
|
||||
vision_heads = vision_width // 64
|
||||
self.visual = VisionTransformer(
|
||||
input_resolution=image_resolution,
|
||||
patch_size=vision_patch_size,
|
||||
width=vision_width,
|
||||
layers=vision_layers,
|
||||
heads=vision_heads,
|
||||
output_dim=embed_dim,
|
||||
finetune=finetune,
|
||||
average_targets=average_targets,
|
||||
)
|
||||
|
||||
self.transformer = Transformer(
|
||||
width=transformer_width,
|
||||
layers=transformer_layers,
|
||||
heads=transformer_heads,
|
||||
attn_mask=self.build_attention_mask())
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
||||
self.positional_embedding = nn.Parameter(
|
||||
torch.empty(self.context_length, transformer_width))
|
||||
self.ln_final = LayerNorm(transformer_width)
|
||||
|
||||
self.text_projection = nn.Parameter(
|
||||
torch.empty(transformer_width, embed_dim))
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
self.initialize_parameters()
|
||||
|
||||
def initialize_parameters(self) -> None:
|
||||
"""Initialize the parameters.
|
||||
|
||||
The pretrained weight will override the initialized parameters by this
|
||||
function.
|
||||
"""
|
||||
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
||||
nn.init.normal_(self.positional_embedding, std=0.01)
|
||||
|
||||
proj_std = (self.transformer.width**-0.5) * (
|
||||
(2 * self.transformer.layers)**-0.5)
|
||||
attn_std = self.transformer.width**-0.5
|
||||
fc_std = (2 * self.transformer.width)**-0.5
|
||||
for block in self.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
|
||||
if self.text_projection is not None:
|
||||
nn.init.normal_(
|
||||
self.text_projection, std=self.transformer.width**-0.5)
|
||||
|
||||
def build_attention_mask(self) -> torch.Tensor:
|
||||
"""Build the attention mask."""
|
||||
# lazily create causal attention mask, with full attention between the
|
||||
# vision tokens pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(self.context_length, self.context_length)
|
||||
mask.fill_(float('-inf'))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
return mask
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
"""Get the dtype."""
|
||||
return self.visual.conv1.weight.dtype
|
||||
|
||||
def encode_image(self,
|
||||
image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Encode the image.
|
||||
|
||||
Get the feature and attention mask from the last layer of the visual
|
||||
branch of CLIP.
|
||||
|
||||
Args:
|
||||
image (torch.Tensor): The image tensor with shape NCHW.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The feature and attention mask.
|
||||
"""
|
||||
return self.visual(image.type(self.dtype))
|
||||
|
||||
|
||||
def build_clip_model(state_dict: dict,
|
||||
finetune: bool = False,
|
||||
average_targets: int = 1) -> nn.Module:
|
||||
"""Build the CLIP model.
|
||||
|
||||
Args:
|
||||
state_dict (dict): The pretrained state dict.
|
||||
finetune (bool): Whether to finetune the model. Defaults to False.
|
||||
average_targets (int): The number of targets to average. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
nn.Module: The CLIP model.
|
||||
"""
|
||||
vit = 'visual.proj' in state_dict
|
||||
|
||||
if vit:
|
||||
vision_width = state_dict['visual.conv1.weight'].shape[0]
|
||||
vision_layers = len([
|
||||
k for k in state_dict.keys()
|
||||
if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
|
||||
])
|
||||
vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
|
||||
grid_size = round(
|
||||
(state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
|
||||
image_resolution = vision_patch_size * grid_size
|
||||
|
||||
embed_dim = state_dict['text_projection'].shape[1]
|
||||
context_length = state_dict['positional_embedding'].shape[0]
|
||||
vocab_size = state_dict['token_embedding.weight'].shape[0]
|
||||
transformer_width = state_dict['ln_final.weight'].shape[0]
|
||||
transformer_heads = transformer_width // 64
|
||||
transformer_layers = len(
|
||||
set(
|
||||
k.split('.')[2] for k in state_dict
|
||||
if k.startswith('transformer.resblocks')))
|
||||
|
||||
model = CLIP(
|
||||
embed_dim,
|
||||
image_resolution,
|
||||
vision_layers,
|
||||
vision_width,
|
||||
vision_patch_size,
|
||||
context_length,
|
||||
vocab_size,
|
||||
transformer_width,
|
||||
transformer_heads,
|
||||
transformer_layers,
|
||||
finetune,
|
||||
average_targets,
|
||||
)
|
||||
|
||||
for key in ['input_resolution', 'context_length', 'vocab_size']:
|
||||
if key in state_dict:
|
||||
del state_dict[key]
|
||||
|
||||
msg = model.load_state_dict(state_dict, strict=False)
|
||||
MMLogger.get_current_instance().info(f'Load CLIP model: {msg}')
|
||||
return model.eval()
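A hedged sketch of wrapping a CLIP checkpoint with this helper; the path is a placeholder and the state dict is expected to follow the original OpenAI CLIP key layout:

import torch
from mmengine.runner.checkpoint import _load_checkpoint

state_dict = _load_checkpoint('path/to/clip_vit_b16.pth')
clip = build_clip_model(state_dict, finetune=False)
# The visual branch returns the projected token features together with the
# attention of the last residual block.
feature, attention = clip.encode_image(torch.randn(1, 3, 224, 224))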
|
|
@ -81,7 +81,7 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
|
|||
else:
|
||||
self._enable_normalize = False
|
||||
|
||||
if batch_augments is not None:
|
||||
if batch_augments:
|
||||
self.batch_augments = RandomBatchAugment(**batch_augments)
|
||||
if not self.to_onehot:
|
||||
from mmengine.logging import MMLogger
|
||||
|
|
|
@ -0,0 +1,232 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Modified from
|
||||
# https://github.com/microsoft/unilm/blob/master/beit2/norm_ema_quantizer.py
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from mmengine.dist import all_reduce
|
||||
|
||||
|
||||
def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor,
|
||||
decay: torch.Tensor) -> None:
|
||||
"""Update moving average."""
|
||||
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
||||
|
||||
|
||||
def norm_ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor,
|
||||
decay: torch.Tensor) -> None:
|
||||
"""Update moving average with norm data."""
|
||||
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
||||
moving_avg.data.copy_(F.normalize(moving_avg.data, p=2, dim=-1))
|
||||
|
||||
|
||||
def sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor:
|
||||
"""Sample vectors according to the given number."""
|
||||
num_samples, device = samples.shape[0], samples.device
|
||||
|
||||
if num_samples >= num:
|
||||
indices = torch.randperm(num_samples, device=device)[:num]
|
||||
else:
|
||||
indices = torch.randint(0, num_samples, (num, ), device=device)
|
||||
|
||||
return samples[indices]
|
||||
|
||||
|
||||
def kmeans(samples: torch.Tensor,
|
||||
num_clusters: int,
|
||||
num_iters: int = 10,
|
||||
use_cosine_sim: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Run k-means algorithm."""
|
||||
dim, dtype, _ = samples.shape[-1], samples.dtype, samples.device
|
||||
|
||||
means = sample_vectors(samples, num_clusters)
|
||||
|
||||
for _ in range(num_iters):
|
||||
if use_cosine_sim:
|
||||
dists = samples @ means.t()
|
||||
else:
|
||||
diffs = rearrange(samples, 'n d -> n () d') \
|
||||
- rearrange(means, 'c d -> () c d')
|
||||
dists = -(diffs**2).sum(dim=-1)
|
||||
|
||||
buckets = dists.max(dim=-1).indices
|
||||
bins = torch.bincount(buckets, minlength=num_clusters)
|
||||
zero_mask = bins == 0
|
||||
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
||||
|
||||
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
||||
new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
|
||||
new_means = new_means / bins_min_clamped[..., None]
|
||||
|
||||
if use_cosine_sim:
|
||||
new_means = F.normalize(new_means, p=2, dim=-1)
|
||||
|
||||
means = torch.where(zero_mask[..., None], means, new_means)
|
||||
|
||||
return means, bins
|
||||
|
||||
|
||||
class EmbeddingEMA(nn.Module):
|
||||
"""The codebook of embedding vectors.
|
||||
|
||||
Args:
|
||||
num_tokens (int): Number of embedding vectors in the codebook.
|
||||
codebook_dim (int) : The dimension of embedding vectors in the
|
||||
codebook.
|
||||
kmeans_init (bool): Whether to use k-means to initialize the
|
||||
VectorQuantizer. Defaults to True.
|
||||
codebook_init_path (str): The initialization checkpoint for codebook.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_tokens: int,
|
||||
codebook_dim: int,
|
||||
kmeans_init: bool = True,
|
||||
codebook_init_path: Optional[str] = None):
|
||||
super().__init__()
|
||||
self.num_tokens = num_tokens
|
||||
self.codebook_dim = codebook_dim
|
||||
if codebook_init_path is None:
|
||||
if not kmeans_init:
|
||||
weight = torch.randn(num_tokens, codebook_dim)
|
||||
weight = F.normalize(weight, p=2, dim=-1)
|
||||
else:
|
||||
weight = torch.zeros(num_tokens, codebook_dim)
|
||||
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
|
||||
else:
|
||||
print(f'load init codebook weight from {codebook_init_path}')
|
||||
codebook_ckpt_weight = torch.load(
|
||||
codebook_init_path, map_location='cpu')
|
||||
weight = codebook_ckpt_weight.clone()
|
||||
self.register_buffer('initted', torch.Tensor([True]))
|
||||
|
||||
self.weight = nn.Parameter(weight, requires_grad=False)
|
||||
self.update = True
|
||||
|
||||
@torch.jit.ignore
|
||||
def init_embed_(self, data: torch.Tensor) -> None:
|
||||
"""Initialize embedding vectors of codebook."""
|
||||
if self.initted:
|
||||
return
|
||||
print('Performing K-means init for codebook')
|
||||
embed, _ = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
|
||||
self.weight.data.copy_(embed)
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
def forward(self, embed_id: torch.Tensor) -> torch.Tensor:
|
||||
"""Get embedding vectors."""
|
||||
return F.embedding(embed_id, self.weight)
|
||||
|
||||
|
||||
class NormEMAVectorQuantizer(nn.Module):
|
||||
"""Normed EMA vector quantizer module.
|
||||
|
||||
Args:
|
||||
num_embed (int): Number of embedding vectors in the codebook. Defaults
|
||||
to 8192.
|
||||
embed_dims (int) : The dimension of embedding vectors in the codebook.
|
||||
Defaults to 32.
|
||||
beta (float): The multiplier for the VectorQuantizer embedding loss.
|
||||
Defaults to 1.
|
||||
decay (float): The decay parameter of EMA. Defaults to 0.99.
|
||||
statistic_code_usage (bool): Whether to use cluster_size to record
|
||||
statistic. Defaults to True.
|
||||
kmeans_init (bool): Whether to use k-means to initialize the
|
||||
VectorQuantizer. Defaults to True.
|
||||
codebook_init_path (str): The initialization checkpoint for codebook.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embed: int,
|
||||
embed_dims: int,
|
||||
beta: float,
|
||||
decay: float = 0.99,
|
||||
statistic_code_usage: bool = True,
|
||||
kmeans_init: bool = True,
|
||||
codebook_init_path: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
self.codebook_dim = embed_dims
|
||||
self.num_tokens = num_embed
|
||||
self.beta = beta
|
||||
self.decay = decay
|
||||
|
||||
# learnable = True if orthogonal_reg_weight > 0 else False
|
||||
self.embedding = EmbeddingEMA(
|
||||
num_tokens=self.num_tokens,
|
||||
codebook_dim=self.codebook_dim,
|
||||
kmeans_init=kmeans_init,
|
||||
codebook_init_path=codebook_init_path)
|
||||
|
||||
self.statistic_code_usage = statistic_code_usage
|
||||
if statistic_code_usage:
|
||||
self.register_buffer('cluster_size', torch.zeros(num_embed))
|
||||
|
||||
def reset_cluster_size(self, device):
|
||||
|
||||
if self.statistic_code_usage:
|
||||
self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
|
||||
self.cluster_size = self.cluster_size.to(device)
|
||||
|
||||
def forward(self, z):
|
||||
"""Forward function."""
|
||||
# reshape z -> (batch, height, width, channel)
|
||||
z = rearrange(z, 'b c h w -> b h w c')
|
||||
z = F.normalize(z, p=2, dim=-1)
|
||||
z_flattened = z.reshape(-1, self.codebook_dim)
|
||||
|
||||
self.embedding.init_embed_(z_flattened)
|
||||
|
||||
# 'n d -> d n'
|
||||
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
|
||||
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
|
||||
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)
|
||||
|
||||
encoding_indices = torch.argmin(d, dim=1)
|
||||
|
||||
z_q = self.embedding(encoding_indices).view(z.shape)
|
||||
|
||||
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
|
||||
|
||||
if not self.training:
|
||||
with torch.no_grad():
|
||||
cluster_size = encodings.sum(0)
|
||||
all_reduce(cluster_size)
|
||||
ema_inplace(self.cluster_size, cluster_size, self.decay)
|
||||
|
||||
if self.training and self.embedding.update:
|
||||
# update cluster size with EMA
|
||||
bins = encodings.sum(0)
|
||||
all_reduce(bins)
|
||||
ema_inplace(self.cluster_size, bins, self.decay)
|
||||
|
||||
zero_mask = (bins == 0)
|
||||
bins = bins.masked_fill(zero_mask, 1.)
|
||||
|
||||
embed_sum = z_flattened.t() @ encodings
|
||||
all_reduce(embed_sum)
|
||||
|
||||
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
|
||||
embed_normalized = F.normalize(embed_normalized, p=2, dim=-1)
|
||||
embed_normalized = torch.where(zero_mask[..., None],
|
||||
self.embedding.weight,
|
||||
embed_normalized)
|
||||
|
||||
# Update embedding vectors with EMA
|
||||
norm_ema_inplace(self.embedding.weight, embed_normalized,
|
||||
self.decay)
|
||||
|
||||
# compute loss for embedding
|
||||
loss = self.beta * F.mse_loss(z_q.detach(), z)
|
||||
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# reshape back to match original input shape
|
||||
z_q = rearrange(z_q, 'b h w c -> b c h w')
|
||||
return z_q, loss, encoding_indices
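A hedged sketch of quantizing a 14x14 feature map with this codebook, using the BEiT v2 defaults of 8192 codes of 32 dimensions:

import torch

quantizer = NormEMAVectorQuantizer(num_embed=8192, embed_dims=32, beta=1.0)
features = torch.randn(2, 32, 14, 14)  # (B, C, H, W) features to be quantized
# Returns the straight-through quantized features, the commitment loss and the
# codebook indices that serve as visual tokens.
z_q, emb_loss, token_ids = quantizer(features)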
|
|
@ -4,7 +4,7 @@ from unittest import TestCase
|
|||
|
||||
import torch
|
||||
|
||||
from mmpretrain.models.backbones import BEiT
|
||||
from mmpretrain.models.backbones import BEiTViT
|
||||
|
||||
|
||||
class TestBEiT(TestCase):
|
||||
|
@ -18,7 +18,7 @@ class TestBEiT(TestCase):
|
|||
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['arch'] = 'unknown'
|
||||
BEiT(**cfg)
|
||||
BEiTViT(**cfg)
|
||||
|
||||
# Test invalid custom arch
|
||||
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
|
||||
|
@ -28,7 +28,7 @@ class TestBEiT(TestCase):
|
|||
'num_heads': 16,
|
||||
'feedforward_channels': 4096
|
||||
}
|
||||
BEiT(**cfg)
|
||||
BEiTViT(**cfg)
|
||||
|
||||
# Test custom arch
|
||||
cfg = deepcopy(self.cfg)
|
||||
|
@ -38,7 +38,7 @@ class TestBEiT(TestCase):
|
|||
'num_heads': 16,
|
||||
'feedforward_channels': 1024
|
||||
}
|
||||
model = BEiT(**cfg)
|
||||
model = BEiTViT(**cfg)
|
||||
self.assertEqual(model.embed_dims, 128)
|
||||
self.assertEqual(model.num_layers, 24)
|
||||
self.assertIsNone(model.pos_embed)
|
||||
|
@ -51,21 +51,21 @@ class TestBEiT(TestCase):
|
|||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = {1: 1}
|
||||
with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
|
||||
BEiT(**cfg)
|
||||
BEiTViT(**cfg)
|
||||
cfg['out_indices'] = [0, 13]
|
||||
with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
|
||||
BEiT(**cfg)
|
||||
BEiTViT(**cfg)
|
||||
|
||||
# Test pos_embed
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['use_abs_pos_emb'] = True
|
||||
model = BEiT(**cfg)
|
||||
model = BEiTViT(**cfg)
|
||||
self.assertEqual(model.pos_embed.shape, (1, 197, 768))
|
||||
|
||||
# Test model structure
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['drop_path_rate'] = 0.1
|
||||
model = BEiT(**cfg)
|
||||
model = BEiTViT(**cfg)
|
||||
self.assertEqual(len(model.layers), 12)
|
||||
dpr_inc = 0.1 / (12 - 1)
|
||||
dpr = 0
|
||||
|
@ -85,7 +85,7 @@ class TestBEiT(TestCase):
|
|||
# test with output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['output_cls_token'] = True
|
||||
model = BEiT(**cfg)
|
||||
model = BEiTViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
|
@ -95,7 +95,7 @@ class TestBEiT(TestCase):
|
|||
|
||||
# test without output_cls_token
|
||||
cfg = deepcopy(self.cfg)
|
||||
model = BEiT(**cfg)
|
||||
model = BEiTViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
|
@ -105,7 +105,7 @@ class TestBEiT(TestCase):
|
|||
# test without average
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['avg_token'] = False
|
||||
model = BEiT(**cfg)
|
||||
model = BEiTViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 1)
|
||||
|
@ -115,7 +115,7 @@ class TestBEiT(TestCase):
|
|||
# Test forward with multi out indices
|
||||
cfg = deepcopy(self.cfg)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = BEiT(**cfg)
|
||||
model = BEiTViT(**cfg)
|
||||
outs = model(imgs)
|
||||
self.assertIsInstance(outs, tuple)
|
||||
self.assertEqual(len(outs), 3)
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmpretrain.models import BEiTPretrainViT
|
||||
|
||||
backbone = dict(
|
||||
arch='base',
|
||||
patch_size=16,
|
||||
drop_path_rate=0.1,
|
||||
final_norm=True,
|
||||
layer_scale_init_value=0.1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_beit_pretrain_vit():
|
||||
beit_backbone = BEiTPretrainViT(**backbone)
|
||||
beit_backbone.init_weights()
|
||||
|
||||
fake_inputs = torch.randn((2, 3, 224, 224))
|
||||
fake_mask = torch.zeros((2, 196))
|
||||
fake_mask[:, 75:150] = 1
|
||||
fake_outputs = beit_backbone(fake_inputs, fake_mask)
|
||||
|
||||
assert list(fake_outputs[0].shape) == [2, 197, 768]
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmpretrain.models import CAEViT
|
||||
|
||||
backbone = dict(arch='b', patch_size=16, layer_scale_init_value=0.1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_cae_vit():
|
||||
cae_backbone = CAEViT(**backbone)
|
||||
cae_backbone.init_weights()
|
||||
fake_inputs = torch.randn((2, 3, 224, 224))
|
||||
fake_mask = torch.zeros((2, 196)).bool()
|
||||
fake_mask[:, 75:150] = 1
|
||||
fake_outputs = cae_backbone(fake_inputs, fake_mask)
|
||||
|
||||
assert list(fake_outputs.shape) == [2, 122, 768]
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmpretrain.models import MAEViT
|
||||
|
||||
backbone = dict(arch='b', patch_size=16, mask_ratio=0.75)
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_mae_vit():
|
||||
mae_backbone = MAEViT(**backbone)
|
||||
mae_backbone.init_weights()
|
||||
fake_inputs = torch.randn((2, 3, 224, 224))
|
||||
fake_outputs = mae_backbone(fake_inputs)[0]
|
||||
|
||||
assert list(fake_outputs.shape) == [2, 50, 768]
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmpretrain.models import MaskFeatViT
|
||||
|
||||
backbone = dict(arch='b', patch_size=16)
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_maskfeat_vit():
|
||||
maskfeat_backbone = MaskFeatViT(**backbone)
|
||||
maskfeat_backbone.init_weights()
|
||||
fake_inputs = torch.randn((2, 3, 224, 224))
|
||||
fake_mask = torch.randn((2, 14, 14))
|
||||
fake_outputs = maskfeat_backbone(fake_inputs, fake_mask)
|
||||
|
||||
assert list(fake_outputs.shape) == [2, 197, 768]
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
|
||||
from mmpretrain.models import MoCoV3ViT
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_vision_transformer():
|
||||
vit = MoCoV3ViT(
|
||||
arch='mocov3-small', patch_size=16, frozen_stages=12, norm_eval=True)
|
||||
vit.init_weights()
|
||||
vit.train()
|
||||
|
||||
for p in vit.parameters():
|
||||
assert p.requires_grad is False
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmpretrain.models import SimMIMSwinTransformer
|
||||
|
||||
backbone = dict(
|
||||
arch='B', img_size=192, stage_cfgs=dict(block_cfgs=dict(window_size=6)))
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_simmim_swin_transformer():
|
||||
simmim_backbone = SimMIMSwinTransformer(**backbone)
|
||||
simmim_backbone.init_weights()
|
||||
fake_inputs = torch.randn((2, 3, 192, 192))
|
||||
fake_mask = torch.rand((2, 48, 48))
|
||||
fake_outputs = simmim_backbone(fake_inputs, fake_mask)[0]
|
||||
|
||||
assert list(fake_outputs.shape) == [2, 1024, 6, 6]
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmpretrain.models import VQKD, Encoder, HOGGenerator
|
||||
|
||||
|
||||
class TestDALLE(TestCase):
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_dalle(self):
|
||||
model = Encoder()
|
||||
fake_inputs = torch.rand((2, 3, 112, 112))
|
||||
fake_outputs = model(fake_inputs)
|
||||
|
||||
assert list(fake_outputs.shape) == [2, 8192, 14, 14]
|
||||
|
||||
|
||||
class TestHOGGenerator(TestCase):
|
||||
|
||||
def test_hog_generator(self):
|
||||
hog_generator = HOGGenerator()
|
||||
|
||||
fake_input = torch.randn((2, 3, 224, 224))
|
||||
fake_output = hog_generator(fake_input)
|
||||
assert list(fake_output.shape) == [2, 196, 108]
|
||||
|
||||
fake_hog_out = hog_generator.out[0].unsqueeze(0)
|
||||
fake_hog_img = hog_generator.generate_hog_image(fake_hog_out)
|
||||
assert fake_hog_img.shape == (224, 224)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
fake_hog_img = hog_generator.generate_hog_image(
|
||||
hog_generator.out[0])
|
||||
|
||||
|
||||
class TestVQKD(TestCase):
|
||||
|
||||
ENCODER_CFG = dict(
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
out_indices=-1,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_cfg=dict(type='LN', eps=1e-6),
|
||||
final_norm=True,
|
||||
with_cls_token=True,
|
||||
avg_token=False,
|
||||
frozen_stages=-1,
|
||||
output_cls_token=False,
|
||||
use_abs_pos_emb=True,
|
||||
use_rel_pos_bias=False,
|
||||
use_shared_rel_pos_bias=False,
|
||||
layer_scale_init_value=0.,
|
||||
interpolate_mode='bicubic',
|
||||
patch_cfg=dict(),
|
||||
layer_cfgs=dict(),
|
||||
init_cfg=None)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_vqkd(self):
|
||||
model = VQKD(encoder_config=self.ENCODER_CFG)
|
||||
fake_inputs = torch.rand((2, 3, 224, 224))
|
||||
fake_outputs = model(fake_inputs)
|
||||
|
||||
assert list(fake_outputs.shape) == [2, 196]
|