[Refactor] Add self-supervised backbones and target generators. (#1379)

* add heads

* add losses

* fix

* remove mim head

* add modified backbones and target generators

* add unittest

* refactor caevit

* add window_size check

* fix lint

* apply new DataSample

* fix ut error

* update ut

* fix ut

* fix lint

* Update base modules.

---------

Co-authored-by: mzr1996 <mzr1996@163.com>
Yixiao Fang 2023-02-28 15:59:17 +08:00 committed by GitHub
parent 63d9f27fde
commit e453a45d31
30 changed files with 2776 additions and 35 deletions

View File

@ -47,7 +47,7 @@ repos:
rev: v0.4.0
hooks:
- id: check-copyright
args: ["mmcls", "tests", "demo", "tools", "--excludes", "mmcls/.mim/", "--ignore-file-not-found-error"]
args: ["mmpretrain", "tests", "demo", "tools", "--excludes", "mmpretrain/.mim/", "--ignore-file-not-found-error"]
- repo: local
hooks:
- id: metafile

View File

@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@ -9,7 +9,6 @@ from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .retrievers import * # noqa: F401,F403
from .selfsup import * # noqa: F401,F403
from .target_generators import * # noqa: F401,F403
from .tta import * # noqa: F401,F403
from .utils import * # noqa: F401,F403

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
from .beit import BEiT
from .beit import BEiTViT
from .conformer import Conformer
from .convmixer import ConvMixer
from .convnext import ConvNeXt
@ -106,7 +106,7 @@ __all__ = [
'HorNet',
'MobileViT',
'DaViT',
'BEiT',
'BEiTViT',
'RevVisionTransformer',
'MixMIMTransformer',
'TinyViT',

View File

@ -155,7 +155,6 @@ class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
drop_path_rate=0.,
drop_rate=0.,
num_fcs=num_fcs,
qkv_bias=bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=init_cfg)
@ -214,7 +213,7 @@ class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
@MODELS.register_module()
class BEiT(VisionTransformer):
class BEiTViT(VisionTransformer):
"""Backbone for BEiT.
A PyTorch implementation of: `BEiT: BERT Pre-Training of Image Transformers
@ -244,8 +243,10 @@ class BEiT(VisionTransformer):
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
bias (bool | str): The option to add learnable bias for q, k, v. If bias
is True, it will add learnable bias. If bias is 'qv_bias', it will
only add learnable bias for q, v. If bias is False, it will not add
bias for q, k, v. Defaults to 'qv_bias'.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
@ -285,6 +286,7 @@ class BEiT(VisionTransformer):
out_indices=-1,
drop_rate=0,
drop_path_rate=0,
bias='qv_bias',
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=False,
with_cls_token=True,
@ -395,6 +397,7 @@ class BEiT(VisionTransformer):
use_rel_pos_bias=use_rel_pos_bias,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
bias=bias,
norm_cfg=norm_cfg)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
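A hedged usage sketch (not part of the diff): building the renamed BEiTViT backbone through the registry with the new `bias` option. The arch and image size below are illustrative assumptions.

from mmpretrain.registry import MODELS

# `bias` may be True, False or 'qv_bias' (learnable bias on q and v only).
backbone = MODELS.build(
    dict(
        type='BEiTViT',
        arch='base',
        img_size=224,
        patch_size=16,
        bias='qv_bias',
        norm_cfg=dict(type='LN', eps=1e-6)))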

View File

@ -30,6 +30,8 @@ class ImageClassifier(BaseClassifier):
- augments (List[dict]): The batch augmentation methods to use.
More details can be found in
:mod:`mmpretrain.model.utils.augment`.
- probs (List[float], optional): The probability of each batch
augmentation method. If None, choose evenly. Defaults to None.
Defaults to None.
data_preprocessor (dict, optional): The config for preprocessing input
@ -51,14 +53,16 @@ class ImageClassifier(BaseClassifier):
if pretrained is not None:
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
if data_preprocessor is None:
data_preprocessor = {}
# The build process is in MMEngine, so we need to add scope here.
data_preprocessor.setdefault('type', 'mmpretrain.ClsDataPreprocessor')
data_preprocessor = data_preprocessor or {}
if train_cfg is not None and 'augments' in train_cfg:
# Set batch augmentations by `train_cfg`
data_preprocessor['batch_augments'] = train_cfg
if isinstance(data_preprocessor, dict):
data_preprocessor.setdefault('type', 'ClsDataPreprocessor')
data_preprocessor.setdefault('batch_augments', train_cfg)
data_preprocessor = MODELS.build(data_preprocessor)
elif not isinstance(data_preprocessor, nn.Module):
raise TypeError('data_preprocessor should be a `dict` or '
f'`nn.Module` instance, but got '
f'{type(data_preprocessor)}')
super(ImageClassifier, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
@ -82,16 +86,13 @@ class ImageClassifier(BaseClassifier):
The method should accept three modes: "tensor", "predict" and "loss":
- "tensor": Forward the whole network and return tensor or tuple of
tensor without any post-processing, same as a common nn.Module.
- "tensor": Forward the whole network and return tensor(s) without any
post-processing, same as a common PyTorch Module.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`DataSample`.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
Note that this method doesn't handle back propagation or
optimizer updating, which are done in :meth:`train_step`.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
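A hedged config sketch of the batch-augmentation interface documented above; the Mixup/CutMix entries, their alpha values and the probabilities are illustrative assumptions, not values taken from this commit.

model = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=50),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(type='LinearClsHead', num_classes=1000, in_channels=2048),
    train_cfg=dict(
        augments=[
            dict(type='Mixup', alpha=0.8),
            dict(type='CutMix', alpha=1.0),
        ],
        # Pick each augmentation with probability 0.5; None chooses evenly.
        probs=[0.5, 0.5]))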

View File

@ -0,0 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseSelfSupervisor
from .beit import VQKD, BEiTPretrainViT
from .cae import CAEViT, Encoder
from .mae import MAEViT
from .maskfeat import HOGGenerator, MaskFeatViT
from .milan import CLIPGenerator, MILANViT
from .mixmim import MixMIMPretrainTransformer
from .mocov3 import MoCoV3ViT
from .simmim import SimMIMSwinTransformer
__all__ = [
'BaseSelfSupervisor',
'BEiTPretrainViT',
'VQKD',
'CAEViT',
'Encoder',
'MAEViT',
'HOGGenerator',
'MaskFeatViT',
'CLIPGenerator',
'MILANViT',
'MixMIMPretrainTransformer',
'MoCoV3ViT',
'SimMIMSwinTransformer',
]

View File

@ -0,0 +1,164 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Union
import torch
from mmengine.model import BaseModel
from torch import nn
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
class BaseSelfSupervisor(BaseModel, metaclass=ABCMeta):
"""BaseModel for Self-Supervised Learning.
All self-supervised algorithms should inherit this module.
Args:
backbone (dict): The backbone module. See
:mod:`mmpretrain.models.backbones`.
neck (dict, optional): The neck module to process features from
backbone. See :mod:`mmpretrain.models.necks`. Defaults to None.
head (dict, optional): The head module to do prediction and calculate
loss from processed features. See :mod:`mmpretrain.models.heads`.
Notice that if the head is not set, almost all methods cannot be
used except :meth:`extract_feat`. Defaults to None.
target_generator (dict, optional): The target_generator module to
generate targets for self-supervised learning optimization, such as
HOG, or features extracted from other modules (DALL-E, CLIP), etc.
Defaults to None.
pretrained (str, optional): The pretrained checkpoint path, support
local path and remote path. Defaults to None.
data_preprocessor (Union[dict, nn.Module], optional): The config for
preprocessing input data. If None or no specified type, it will use
"SelfSupDataPreprocessor" as type.
See :class:`SelfSupDataPreprocessor` for more details.
Defaults to None.
init_cfg (dict, optional): the config to control the initialization.
Defaults to None.
"""
def __init__(self,
backbone: dict,
neck: Optional[dict] = None,
head: Optional[dict] = None,
target_generator: Optional[dict] = None,
pretrained: Optional[str] = None,
data_preprocessor: Optional[Union[dict, nn.Module]] = None,
init_cfg: Optional[dict] = None):
if pretrained is not None:
init_cfg = dict(type='Pretrained', checkpoint=pretrained)
data_preprocessor = data_preprocessor or {}
if isinstance(data_preprocessor, dict):
data_preprocessor.setdefault('type', 'SelfSupDataPreprocessor')
data_preprocessor = MODELS.build(data_preprocessor)
elif not isinstance(data_preprocessor, nn.Module):
raise TypeError('data_preprocessor should be a `dict` or '
f'`nn.Module` instance, but got '
f'{type(data_preprocessor)}')
super().__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor)
if not isinstance(backbone, nn.Module):
backbone = MODELS.build(backbone)
if neck is not None and not isinstance(neck, nn.Module):
neck = MODELS.build(neck)
if head is not None and not isinstance(head, nn.Module):
head = MODELS.build(head)
if target_generator is not None and not isinstance(
target_generator, nn.Module):
target_generator = MODELS.build(target_generator)
self.backbone = backbone
self.neck = neck
self.head = head
self.target_generator = target_generator
@property
def with_neck(self) -> bool:
"""Check if the model has a neck module."""
return hasattr(self, 'neck') and self.neck is not None
@property
def with_head(self) -> bool:
"""Check if the model has a head module."""
return hasattr(self, 'head') and self.head is not None
@property
def with_target_generator(self) -> bool:
"""Check if the model has a target_generator module."""
return hasattr(
self, 'target_generator') and self.target_generator is not None
def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List[DataSample]] = None,
mode: str = 'tensor'):
"""The unified entry for a forward process in both training and test.
The method accepts two modes: "tensor" and "loss":
- "tensor": Forward the backbone network and return the feature
tensor(s) without any post-processing, same as a common
PyTorch Module.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[DataSample], optional): The other data of
every sample. It's required by some algorithms
if ``mode="loss"``. Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
Returns:
The return type depends on ``mode``.
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'tensor':
feats = self.extract_feat(inputs)
return feats
elif mode == 'loss':
return self.loss(inputs, data_samples)
else:
raise RuntimeError(f'Invalid mode "{mode}".')
@abstractmethod
def extract_feat(self, inputs: torch.Tensor):
"""Extract features from the input tensor with shape (N, C, ...).
Subclasses are recommended to implement this method to extract
features from the backbone and neck.
Args:
inputs (Tensor): A batch of inputs. The shape of it should be
``(num_samples, num_channels, *img_shape)``.
Returns:
tuple | Tensor: The output feature tensor(s).
"""
raise NotImplementedError
@abstractmethod
def loss(self, inputs: torch.Tensor,
data_samples: List[DataSample]) -> dict:
"""Calculate losses from a batch of inputs and data samples.
This is an abstract method, and subclasses should override it
if needed.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, ...) in general.
data_samples (List[DataSample]): The annotation data of
every sample.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
raise NotImplementedError
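A minimal sketch (not part of the diff) of how a concrete algorithm could subclass BaseSelfSupervisor; the `ToySelfSup` name and its placeholder loss are hypothetical and only illustrate the two abstract methods.

import torch

from mmpretrain.models.selfsup import BaseSelfSupervisor
from mmpretrain.registry import MODELS


@MODELS.register_module()
class ToySelfSup(BaseSelfSupervisor):
    """A hypothetical algorithm that only wires up the abstract methods."""

    def extract_feat(self, inputs: torch.Tensor):
        # Delegate feature extraction to the backbone (and neck, if any).
        feats = self.backbone(inputs)
        return self.neck(feats) if self.with_neck else feats

    def loss(self, inputs: torch.Tensor, data_samples) -> dict:
        feats = self.backbone(inputs)
        # Placeholder objective; a real algorithm would use its head here.
        return dict(loss=feats[-1].mean())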

View File

@ -0,0 +1,280 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union
import torch
from einops import rearrange
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_
from torch import nn
from mmpretrain.models import BEiTViT
from mmpretrain.models.utils import NormEMAVectorQuantizer, resize_pos_embed
from mmpretrain.registry import MODELS
@MODELS.register_module()
class VQKD(BaseModule):
"""Vector-Quantized Knowledge Distillation.
The module only contains the encoder and the VectorQuantizer part.
Modified from https://github.com/microsoft/unilm/blob/master/beit2/modeling_vqkd.py
Args:
encoder_config (dict): The config of encoder.
decoder_config (dict, optional): The config of decoder. Currently,
VQKD only supports building the encoder. Defaults to None.
num_embed (int): Number of embedding vectors in the codebook. Defaults
to 8192.
embed_dims (int) : The dimension of embedding vectors in the codebook.
Defaults to 32.
decay (float): The decay parameter of EMA. Defaults to 0.99.
beta (float): The multiplier for VectorQuantizer loss. Defaults to 1.
quantize_kmeans_init (bool): Whether to use k-means to initialize the
VectorQuantizer. Defaults to True.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
""" # noqa: E501
def __init__(self,
encoder_config: dict,
decoder_config: Optional[dict] = None,
num_embed: int = 8192,
embed_dims: int = 32,
decay: float = 0.99,
beta: float = 1.0,
quantize_kmeans_init: bool = True,
init_cfg: Optional[dict] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.encoder = BEiTViT(**encoder_config)
if decoder_config is not None:
self.decoder = BEiTViT(**decoder_config)
self.quantize = NormEMAVectorQuantizer(
num_embed=num_embed,
embed_dims=embed_dims,
beta=beta,
decay=decay,
kmeans_init=quantize_kmeans_init,
)
# task layer
self.encode_task_layer = nn.Sequential(
nn.Linear(self.encoder.arch_settings['embed_dims'],
self.encoder.arch_settings['embed_dims']), nn.Tanh(),
nn.Linear(self.encoder.arch_settings['embed_dims'], embed_dims))
def get_tokens(self, x: torch.Tensor) -> dict:
"""Get tokens for beit pre-training."""
_, embed_ind, _ = self.encode(x)
output = {}
output['token'] = embed_ind.view(x.shape[0], -1)
output['input_img'] = x
return output
def encode(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Encode the input images and get corresponding results."""
encoder_features = self.encoder(x)[0]
B, C, N1, N2 = encoder_features.shape
encoder_features = encoder_features.permute(0, 2, 3,
1).reshape(B, N1 * N2, C)
with torch.cuda.amp.autocast(enabled=False):
to_quantizer_features = self.encode_task_layer(
encoder_features.type_as(self.encode_task_layer[-1].weight))
N = to_quantizer_features.shape[1]
h, w = int(math.sqrt(N)), int(math.sqrt(N))
to_quantizer_features = rearrange(
to_quantizer_features, 'b (h w) c -> b c h w', h=h,
w=w) # reshape for quantizer
quantize, loss, embed_ind = self.quantize(to_quantizer_features)
return quantize, embed_ind, loss
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""The forward function.
Currently, it only supports getting tokens.
"""
return self.get_tokens(x)['token']
@MODELS.register_module()
class BEiTPretrainViT(BEiTViT):
"""Vision Transformer for BEiT pre-training.
Args:
arch (str | dict): Vision Transformer architecture. If using a string,
choose from 'small', 'base' and 'large'. If using a dict, it should
have the below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
Defaults to 'base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
avg_token (bool): Whether or not to use the mean patch token for
classification. If True, the model will only take the average
of all patch tokens. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
output_cls_token (bool): Whether to output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
use_abs_pos_emb (bool): Whether or not to use absolute position embedding.
Defaults to False.
use_rel_pos_bias (bool): Whether or not to use relative position bias.
Defaults to False.
use_shared_rel_pos_bias (bool): Whether or not to use shared relative
position bias. Defaults to True.
layer_scale_init_value (float): The initialization value for
the learnable scaling of attention and FFN. Defaults to 0.1.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
arch: str = 'base',
img_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
out_indices: int = -1,
drop_rate: float = 0,
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
avg_token: bool = False,
frozen_stages: int = -1,
output_cls_token: bool = True,
use_abs_pos_emb: bool = False,
use_rel_pos_bias: bool = False,
use_shared_rel_pos_bias: bool = True,
layer_scale_init_value: float = 0.1,
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(padding=0),
layer_cfgs: dict = dict(),
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
arch=arch,
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
out_indices=out_indices,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
final_norm=final_norm,
avg_token=avg_token,
frozen_stages=frozen_stages,
output_cls_token=output_cls_token,
use_abs_pos_emb=use_abs_pos_emb,
use_shared_rel_pos_bias=use_shared_rel_pos_bias,
use_rel_pos_bias=use_rel_pos_bias,
layer_scale_init_value=layer_scale_init_value,
interpolate_mode=interpolate_mode,
patch_cfg=patch_cfg,
layer_cfgs=layer_cfgs,
init_cfg=init_cfg)
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
def init_weights(self) -> None:
"""Initialize position embedding, patch embedding and cls token."""
super().init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
trunc_normal_(self.cls_token, std=0.02)
trunc_normal_(self.mask_token, std=0.02)
self.rescale_init_weight()
def rescale_init_weight(self) -> None:
"""Rescale the initialized weights."""
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.layers):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.ffn.layers[1].weight.data, layer_id + 1)
def forward(self, x: torch.Tensor,
mask: torch.Tensor) -> Tuple[torch.Tensor]:
"""The BEiT style forward function.
Args:
x (torch.Tensor): Input images, which is of shape (B x C x H x W).
mask (torch.Tensor): Mask for input, which is of shape
(B x patch_resolution[0] x patch_resolution[1]).
Returns:
Tuple[torch.Tensor]: Hidden features.
"""
x, patch_resolution = self.patch_embed(x)
# replace the masked visual tokens by mask_token
B, L, _ = x.shape
mask_token = self.mask_token.expand(B, L, -1)
w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
x = x * (1. - w) + mask_token * w
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
self.shared_rel_pos_bias = self.rel_pos_bias().to(
mask.device) if self.rel_pos_bias is not None else None
outs = []
for i, layer in enumerate(self.layers):
x = layer(x, rel_pos_bias=self.shared_rel_pos_bias)
if i == len(self.layers) - 1 and self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
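A hedged sketch (not part of the diff) of running the BEiTPretrainViT encoder on a randomly masked batch; the shapes assume the default 'base' arch with 224x224 inputs and 16x16 patches, and the random mask stands in for the one a real pipeline would provide.

import torch

from mmpretrain.registry import MODELS

backbone = MODELS.build(dict(type='BEiTPretrainViT', arch='base'))
imgs = torch.rand(2, 3, 224, 224)
# Boolean mask over the 14x14 patch grid; True marks a masked patch.
mask = torch.rand(2, 14, 14) > 0.6
feats = backbone(imgs, mask)
print(feats[-1].shape)  # expected: (2, 197, 768), i.e. cls token + 196 patches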

View File

@ -0,0 +1,348 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Part of code is modified from BEiT
# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py
import math
from collections import OrderedDict
from functools import partial
from typing import List, Optional, Union
import attr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.models import VisionTransformer
from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer
from mmpretrain.registry import MODELS
from ..utils import build_2d_sincos_position_embedding
@attr.s(eq=False)
class Conv2d(nn.Module):
n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)
use_float16: bool = attr.ib(default=True)
device: torch.device = attr.ib(default=torch.device('cpu'))
requires_grad: bool = attr.ib(default=False)
def __attrs_post_init__(self) -> None:
super().__init__()
w = torch.empty((self.n_out, self.n_in, self.kw, self.kw),
dtype=torch.float32,
device=self.device,
requires_grad=self.requires_grad)
w.normal_(std=1 / math.sqrt(self.n_in * self.kw**2))
b = torch.zeros((self.n_out, ),
dtype=torch.float32,
device=self.device,
requires_grad=self.requires_grad)
self.w, self.b = nn.Parameter(w), nn.Parameter(b)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_float16 and 'cuda' in self.w.device.type:
if x.dtype != torch.float16:
x = x.half()
w, b = self.w.half(), self.b.half()
else:
if x.dtype != torch.float32:
x = x.float()
w, b = self.w, self.b
return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
@attr.s(eq=False, repr=False)
class EncoderBlock(nn.Module):
n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)
device: torch.device = attr.ib(default=None)
requires_grad: bool = attr.ib(default=False)
def __attrs_post_init__(self) -> None:
super().__init__()
self.n_hid = self.n_out // 4
self.post_gain = 1 / (self.n_layers**2)
make_conv = partial(
Conv2d, device=self.device, requires_grad=self.requires_grad)
self.id_path = make_conv(
self.n_in, self.n_out,
1) if self.n_in != self.n_out else nn.Identity()
self.res_path = nn.Sequential(
OrderedDict([
('relu_1', nn.ReLU()),
('conv_1', make_conv(self.n_in, self.n_hid, 3)),
('relu_2', nn.ReLU()),
('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
('relu_3', nn.ReLU()),
('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
('relu_4', nn.ReLU()),
('conv_4', make_conv(self.n_hid, self.n_out, 1)),
]))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.id_path(x) + self.post_gain * self.res_path(x)
@attr.s(eq=False, repr=False)
@MODELS.register_module(name='DALL-E')
class Encoder(BaseModule):
group_count: int = 4
n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)
device: torch.device = attr.ib(default=torch.device('cpu'))
requires_grad: bool = attr.ib(default=False)
use_mixed_precision: bool = attr.ib(default=True)
init_cfg: Optional[Union[dict, List[dict]]] = attr.ib(default=None)
def __attrs_post_init__(self) -> None:
super().__init__(init_cfg=self.init_cfg)
blk_range = range(self.n_blk_per_group)
n_layers = self.group_count * self.n_blk_per_group
make_conv = partial(
Conv2d, device=self.device, requires_grad=self.requires_grad)
make_blk = partial(
EncoderBlock,
n_layers=n_layers,
device=self.device,
requires_grad=self.requires_grad)
self.blocks = nn.Sequential(
OrderedDict([
('input', make_conv(self.input_channels, 1 * self.n_hid, 7)),
('group_1',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}',
make_blk(1 * self.n_hid, 1 * self.n_hid))
for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_2',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}',
make_blk(
1 * self.n_hid if i == 0 else 2 * self.n_hid,
2 * self.n_hid)) for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_3',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}',
make_blk(
2 * self.n_hid if i == 0 else 4 * self.n_hid,
4 * self.n_hid)) for i in blk_range],
('pool', nn.MaxPool2d(kernel_size=2)),
]))),
('group_4',
nn.Sequential(
OrderedDict([
*[(f'block_{i + 1}',
make_blk(
4 * self.n_hid if i == 0 else 8 * self.n_hid,
8 * self.n_hid)) for i in blk_range],
]))),
('output',
nn.Sequential(
OrderedDict([
('relu', nn.ReLU()),
('conv',
make_conv(
8 * self.n_hid,
self.vocab_size,
1,
use_float16=False)),
]))),
]))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.float()
if len(x.shape) != 4:
raise ValueError(f'input shape {x.shape} is not 4d')
if x.shape[1] != self.input_channels:
raise ValueError(f'input has {x.shape[1]} channels but model \
built for {self.input_channels}')
if x.dtype != torch.float32:
raise ValueError('input must have dtype torch.float32')
return self.blocks(x)
@MODELS.register_module()
class CAEViT(VisionTransformer):
"""Vision Transformer for CAE pre-training.
Rewritten version of: `An Image is Worth 16x16 Words: Transformers
for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
Args:
arch (str | dict): Vision Transformer architecture. Default: 'b'
img_size (int | tuple): Input image size
patch_size (int | tuple): The patch size
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
bias (bool | str): The option to add learnable bias for q, k, v. If bias
is True, it will add learnable bias. If bias is 'qv_bias', it will
only add learnable bias for q, v. If bias is False, it will not add
bias for q, k, v. Defaults to 'qv_bias'.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
output_cls_token (bool): Whether to output the cls_token. If set True,
`with_cls_token` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
layer_scale_init_value (float, optional): The init value of gamma in
BEiTTransformerEncoderLayer. Defaults to None.
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
arch: str = 'b',
img_size: int = 224,
patch_size: int = 16,
out_indices: int = -1,
drop_rate: float = 0,
drop_path_rate: float = 0,
bias: Union[bool, str] = 'qv_bias',
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
output_cls_token: bool = True,
interpolate_mode: str = 'bicubic',
layer_scale_init_value: float = None,
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
init_cfg: dict = None) -> None:
super().__init__(
arch=arch,
img_size=img_size,
patch_size=patch_size,
out_indices=out_indices,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
final_norm=final_norm,
output_cls_token=output_cls_token,
interpolate_mode=interpolate_mode,
patch_cfg=patch_cfg,
layer_cfgs=layer_cfgs,
init_cfg=init_cfg)
self.pos_embed.requires_grad = False
self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
dpr = np.linspace(0, drop_path_rate, self.num_layers)
# Replace original TransformerEncoderLayer with
# BEiTTransformerEncoderLayer
self.layers = ModuleList()
if isinstance(layer_cfgs, dict):
layer_cfgs = [layer_cfgs] * self.num_layers
for i in range(self.num_layers):
_layer_cfg = dict(
embed_dims=self.embed_dims,
num_heads=self.arch_settings['num_heads'],
feedforward_channels=self.
arch_settings['feedforward_channels'],
layer_scale_init_value=layer_scale_init_value,
window_size=None,
# setting `use_rel_pos_bias` to False ignores the `window_size`
use_rel_pos_bias=False,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
bias=bias,
norm_cfg=norm_cfg)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(BEiTTransformerEncoderLayer(**_layer_cfg))
def init_weights(self) -> None:
"""Initialize position embedding, patch embedding and cls token."""
super().init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# initialize position embedding in backbone
pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.pos_embed.shape[-1],
cls_token=True)
self.pos_embed.data.copy_(pos_embed.float())
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m) -> None:
"""Initialize the weights."""
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Generate features for masked images.
This function masks the input images and gets the hidden features of
the visible patches.
Args:
img (torch.Tensor): Input images, which is of shape B x C x H x W.
mask (torch.Tensor): Mask for input, which is of shape B x L.
Returns:
torch.Tensor: hidden features.
"""
x, _ = self.patch_embed(img)
batch_size, _, dim = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
# NOTE: unmasked embeddings
x_unmasked = x[~mask].reshape(batch_size, -1, dim)
x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1)
pos_embed = self.pos_embed.expand(batch_size, self.num_patches + 1,
dim)
pos_embed_unmasked = pos_embed[:,
1:][~mask].reshape(batch_size, -1, dim)
pos_embed_unmasked = torch.cat((pos_embed[:, :1], pos_embed_unmasked),
dim=1)
x_unmasked = x_unmasked + pos_embed_unmasked
x_unmasked = self.drop_after_pos(x_unmasked)
for i, layer in enumerate(self.layers):
x_unmasked = layer(x=x_unmasked, rel_pos_bias=None)
if i == len(self.layers) - 1 and self.final_norm:
x_unmasked = self.norm1(x_unmasked)
return x_unmasked
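A hedged sketch (not part of the diff) of the CAE encoder, which only processes the visible patches; the half-and-half mask and the layer_scale_init_value below are illustrative assumptions.

import torch

from mmpretrain.registry import MODELS

backbone = MODELS.build(
    dict(type='CAEViT', arch='b', patch_size=16, layer_scale_init_value=0.1))
imgs = torch.rand(2, 3, 224, 224)
# Mask exactly half of the 196 patches in every sample; CAE keeps the rest.
mask = torch.zeros(2, 196, dtype=torch.bool)
mask[:, ::2] = True
visible = backbone(imgs, mask)
print(visible.shape)  # expected: (2, 99, 768), i.e. cls token + 98 visible patches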

View File

@ -0,0 +1,178 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Tuple, Union
import torch
from mmpretrain.models import VisionTransformer
from mmpretrain.registry import MODELS
from ..utils import build_2d_sincos_position_embedding
@MODELS.register_module()
class MAEViT(VisionTransformer):
"""Vision Transformer for MAE pre-training.
A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers
for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
This module implements the patch masking in MAE and initializes the
position embedding with sine-cosine position embedding.
Args:
arch (str | dict): Vision Transformer architecture
Default: 'b'
img_size (int | tuple): Input image size
patch_size (int | tuple): The patch size
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
output_cls_token (bool): Whether to output the cls_token. If set True,
`with_cls_token` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
mask_ratio (float): The ratio of total number of patches to be masked.
Defaults to 0.75.
init_cfg (Union[List[dict], dict], optional): Initialization config
dict. Defaults to None.
"""
def __init__(self,
arch: Union[str, dict] = 'b',
img_size: int = 224,
patch_size: int = 16,
out_indices: Union[Sequence, int] = -1,
drop_rate: float = 0,
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
output_cls_token: bool = True,
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
mask_ratio: float = 0.75,
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
arch=arch,
img_size=img_size,
patch_size=patch_size,
out_indices=out_indices,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
final_norm=final_norm,
output_cls_token=output_cls_token,
interpolate_mode=interpolate_mode,
patch_cfg=patch_cfg,
layer_cfgs=layer_cfgs,
init_cfg=init_cfg)
# position embedding is not learnable during pretraining
self.pos_embed.requires_grad = False
self.mask_ratio = mask_ratio
self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
def init_weights(self) -> None:
"""Initialize position embedding, patch embedding and cls token."""
super().init_weights()
pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.pos_embed.shape[-1],
cls_token=True)
self.pos_embed.data.copy_(pos_embed.float())
w = self.patch_embed.projection.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
torch.nn.init.normal_(self.cls_token, std=.02)
def random_masking(
self,
x: torch.Tensor,
mask_ratio: float = 0.75
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate the mask for MAE Pre-training.
Args:
x (torch.Tensor): Image with data augmentation applied, which is
of shape B x L x C.
mask_ratio (float): The mask ratio of total patches.
Defaults to 0.75.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
masked image, mask and the ids to restore original image.
- x_masked (torch.Tensor): masked image.
- mask (torch.Tensor): mask used to mask image.
- ids_restore (torch.Tensor): ids to restore original image.
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(
noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate features for masked images.
This function generates the mask, masks some patches randomly, and gets
the hidden features of the visible patches.
Args:
x (torch.Tensor): Input images, which is of shape B x C x H x W.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Hidden features, mask and the ids to restore original image.
- x (torch.Tensor): hidden features, which is of shape
B x (L * (1 - mask_ratio) + 1) x C.
- mask (torch.Tensor): mask used to mask image.
- ids_restore (torch.Tensor): ids to restore original image.
"""
B = x.shape[0]
x = self.patch_embed(x)[0]
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for _, layer in enumerate(self.layers):
x = layer(x)
# Use final norm
x = self.norm1(x)
return (x, mask, ids_restore)
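A hedged sketch (not part of the diff) of the MAE masking arithmetic: with mask_ratio=0.75 a 224x224 image yields 196 patches, of which int(196 * 0.25) = 49 stay visible, so the encoder sees 50 tokens including the cls token.

import torch

from mmpretrain.registry import MODELS

backbone = MODELS.build(dict(type='MAEViT', arch='b', mask_ratio=0.75))
imgs = torch.rand(2, 3, 224, 224)
feats, mask, ids_restore = backbone(imgs)
print(feats.shape)        # (2, 50, 768): 1 cls token + 49 visible patches
print(mask.shape)         # (2, 196), where 1 marks a masked patch
print(ids_restore.shape)  # (2, 196), used later to unshuffle the patches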

View File

@ -0,0 +1,275 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Sequence, Union
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmpretrain.models import VisionTransformer
from mmpretrain.registry import MODELS
@MODELS.register_module()
class HOGGenerator(BaseModule):
"""Generate HOG feature for images.
This module is used in MaskFeat to generate HOG feature. The code is
modified from file `slowfast/models/operators.py
<https://github.com/facebookresearch/SlowFast/blob/main/slowfast/models/operators.py>`_.
Here is the link of `HOG wikipedia
<https://en.wikipedia.org/wiki/Histogram_of_oriented_gradients>`_.
Args:
nbins (int): Number of bins. Defaults to 9.
pool (int): The pooling cell size. Defaults to 8.
gaussian_window (int): Size of the Gaussian kernel. Defaults to 16.
"""
def __init__(self,
nbins: int = 9,
pool: int = 8,
gaussian_window: int = 16) -> None:
super().__init__()
self.nbins = nbins
self.pool = pool
self.pi = math.pi
weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1).contiguous()
weight_y = weight_x.transpose(2, 3).contiguous()
self.register_buffer('weight_x', weight_x)
self.register_buffer('weight_y', weight_y)
self.gaussian_window = gaussian_window
if gaussian_window:
gaussian_kernel = self.get_gaussian_kernel(gaussian_window,
gaussian_window // 2)
self.register_buffer('gaussian_kernel', gaussian_kernel)
def get_gaussian_kernel(self, kernlen: int, std: int) -> torch.Tensor:
"""Returns a 2D Gaussian kernel array."""
def _gaussian_fn(kernlen: int, std: int) -> torch.Tensor:
n = torch.arange(0, kernlen).float()
n -= n.mean()
n /= std
w = torch.exp(-0.5 * n**2)
return w
kernel_1d = _gaussian_fn(kernlen, std)
kernel_2d = kernel_1d[:, None] * kernel_1d[None, :]
return kernel_2d / kernel_2d.sum()
def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor:
"""Reshape HOG Features for output."""
hog_feat = hog_feat.flatten(1, 2)
self.unfold_size = hog_feat.shape[-1] // 14
hog_feat = hog_feat.permute(0, 2, 3, 1)
hog_feat = hog_feat.unfold(1, self.unfold_size,
self.unfold_size).unfold(
2, self.unfold_size, self.unfold_size)
hog_feat = hog_feat.flatten(1, 2).flatten(2)
return hog_feat
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Generate hog feature for each batch images.
Args:
x (torch.Tensor): Input images of shape (N, 3, H, W).
Returns:
torch.Tensor: HOG features.
"""
# input is RGB image with shape [B 3 H W]
self.h, self.w = x.size(-2), x.size(-1)
x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect')
gx_rgb = F.conv2d(
x, self.weight_x, bias=None, stride=1, padding=0, groups=3)
gy_rgb = F.conv2d(
x, self.weight_y, bias=None, stride=1, padding=0, groups=3)
norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1)
phase = torch.atan2(gx_rgb, gy_rgb)
phase = phase / self.pi * self.nbins # [-9, 9]
b, c, h, w = norm_rgb.shape
out = torch.zeros((b, c, self.nbins, h, w),
dtype=torch.float,
device=x.device)
phase = phase.view(b, c, 1, h, w)
norm_rgb = norm_rgb.view(b, c, 1, h, w)
if self.gaussian_window:
if h != self.gaussian_window:
assert h % self.gaussian_window == 0, 'h {} gw {}'.format(
h, self.gaussian_window)
repeat_rate = h // self.gaussian_window
temp_gaussian_kernel = self.gaussian_kernel.repeat(
[repeat_rate, repeat_rate])
else:
temp_gaussian_kernel = self.gaussian_kernel
norm_rgb *= temp_gaussian_kernel
out.scatter_add_(2, phase.floor().long() % self.nbins, norm_rgb)
out = out.unfold(3, self.pool, self.pool)
out = out.unfold(4, self.pool, self.pool)
out = out.sum(dim=[-1, -2])
self.out = F.normalize(out, p=2, dim=2)
return self._reshape(self.out)
def generate_hog_image(self, hog_out: torch.Tensor) -> np.ndarray:
"""Generate HOG image according to HOG features."""
assert hog_out.size(0) == 1 and hog_out.size(1) == 3, \
'Check the input batch size and the channel number, only support'\
'"batch_size = 1".'
hog_image = np.zeros([self.h, self.w])
cell_gradient = np.array(hog_out.mean(dim=1).squeeze().detach().cpu())
cell_width = self.pool / 2
max_mag = np.array(cell_gradient).max()
angle_gap = 360 / self.nbins
for x in range(cell_gradient.shape[1]):
for y in range(cell_gradient.shape[2]):
cell_grad = cell_gradient[:, x, y]
cell_grad /= max_mag
angle = 0
for magnitude in cell_grad:
angle_radian = math.radians(angle)
x1 = int(x * self.pool +
magnitude * cell_width * math.cos(angle_radian))
y1 = int(y * self.pool +
magnitude * cell_width * math.sin(angle_radian))
x2 = int(x * self.pool -
magnitude * cell_width * math.cos(angle_radian))
y2 = int(y * self.pool -
magnitude * cell_width * math.sin(angle_radian))
magnitude = 0 if magnitude < 0 else magnitude
cv2.line(hog_image, (y1, x1), (y2, x2),
int(255 * math.sqrt(magnitude)))
angle += angle_gap
return hog_image
@MODELS.register_module()
class MaskFeatViT(VisionTransformer):
"""Vision Transformer for MaskFeat pre-training.
A PyTorch implementation of: `Masked Feature Prediction for Self-Supervised
Visual Pre-Training <https://arxiv.org/abs/2112.09133>`_.
Args:
arch (str | dict): Vision Transformer architecture
Default: 'b'
img_size (int | tuple): Input image size
patch_size (int | tuple): The patch size
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
output_cls_token (bool): Whether to output the cls_token. If set True,
`with_cls_token` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
arch: Union[str, dict] = 'b',
img_size: int = 224,
patch_size: int = 16,
out_indices: Union[Sequence, int] = -1,
drop_rate: float = 0,
drop_path_rate: float = 0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
output_cls_token: bool = True,
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
arch=arch,
img_size=img_size,
patch_size=patch_size,
out_indices=out_indices,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
final_norm=final_norm,
output_cls_token=output_cls_token,
interpolate_mode=interpolate_mode,
patch_cfg=patch_cfg,
layer_cfgs=layer_cfgs,
init_cfg=init_cfg)
self.mask_token = nn.parameter.Parameter(
torch.zeros(1, 1, self.embed_dims), requires_grad=True)
self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
def init_weights(self) -> None:
"""Initialize position embedding, mask token and cls token."""
super().init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
nn.init.trunc_normal_(self.cls_token, std=.02)
nn.init.trunc_normal_(self.mask_token, std=.02)
nn.init.trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m: torch.nn.Module) -> None:
if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Generate features for masked images.
Args:
x (torch.Tensor): Input images.
mask (torch.Tensor): Input masks.
Returns:
torch.Tensor: Features with cls_tokens.
"""
B = x.shape[0]
x = self.patch_embed(x)[0]
# replace the masked patch tokens with the learnable mask token
B, L, _ = x.shape
mask_tokens = self.mask_token.expand(B, L, -1)
mask = mask.flatten(1).unsqueeze(-1)
x = x * (1 - mask.int()) + mask_tokens * mask
# append cls token
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.drop_after_pos(x)
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.norm1(x)
return x
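A hedged sketch (not part of the diff) of the HOG target shapes: with nbins=9 and pool=8, a 224x224 RGB image gives a 28x28 grid of cells that is regrouped into 14x14 patch targets of dimension 3 * 9 * 2 * 2 = 108.

import torch

from mmpretrain.registry import MODELS

hog = MODELS.build(dict(type='HOGGenerator', nbins=9, pool=8, gaussian_window=16))
imgs = torch.rand(2, 3, 224, 224)
targets = hog(imgs)
print(targets.shape)  # expected: (2, 196, 108)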

View File

@ -0,0 +1,149 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
from mmengine.model import BaseModule
from mmengine.runner.checkpoint import _load_checkpoint
from mmpretrain.registry import MODELS
from ..utils import build_clip_model
from .mae import MAEViT
@MODELS.register_module()
class CLIPGenerator(BaseModule):
"""Get the features and attention from the last layer of CLIP.
This module is used to generate target features in masked image modeling.
Args:
tokenizer_path (str): The path of the checkpoint of CLIP.
"""
def __init__(self, tokenizer_path: str) -> None:
super().__init__()
self.tokenizer = build_clip_model(
_load_checkpoint(tokenizer_path), False)
@torch.no_grad()
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the features and attention from the last layer of CLIP.
Args:
x (torch.Tensor): The input image, which is of shape (N, 3, H, W).
Returns:
Tuple[torch.Tensor, torch.Tensor]:
The features and attention from the last layer of CLIP,
which are of shape (N, L, C) and (N, L, L), respectively.
"""
# use the visual branch of CLIP to get the features
clip_features = self.tokenizer.encode_image(x)
return clip_features
@MODELS.register_module()
class MILANViT(MAEViT):
"""Vision Transformer for MILAN pre-training.
Implementation of the encoder for `MILAN: Masked Image Pretraining on
Language Assisted Representation <https://arxiv.org/abs/2208.06049>`_.
This module inherits from MAEViT and only overrides the forward function,
replacing random masking with attention masking.
"""
def attention_masking(
self, x: torch.Tensor, mask_ratio: float, importance: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate attention mask for MILAN.
This is what differs from MAEViT, which uses random masking. Attention
masking selects the patches to keep according to their importance: the
higher the importance, the more likely the patch is kept.
Args:
x (torch.Tensor): Input images, which is of shape B x L x C.
mask_ratio (float): The ratio of patches to be masked.
importance (torch.Tensor): Importance of each patch, which is of
shape B x L.
Returns:
Tuple[torch.Tensor, ...]:
- ``x_masked``: masked image
- ``ids_restore``: the ids to restore original image
- ``ids_keep``: ids of the kept patches
- ``ids_dump``: ids of the removed patches
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = importance.to(x.device) # large is keep, small is remove
# sort noise for each sample
ids_shuffle = torch.multinomial(noise, L, replacement=False)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
ids_dump = ids_shuffle[:, len_keep:]
x_masked = torch.gather(
x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, ids_restore, ids_keep, ids_dump
def forward(
self, x: torch.Tensor, importance: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate features for masked images.
This function generates the mask according to the patch importance and
gets the hidden features of the visible patches. The importance is
calculated by CLIP: the CLIP score is computed by the cross attention
between the class token and all other tokens from the last layer, and
the higher the score, the more likely the patch is kept.
Args:
x (torch.Tensor): Input images, which is of shape B x C x H x W.
importance (torch.Tensor): Importance of each patch, which is of
shape B x L.
Returns:
Tuple[torch.Tensor, ...]:
masked image, the ids to restore original image, ids of the
kept patches, ids of the removed patches.
- x (torch.Tensor): hidden features, which is of shape
B x (L * (1 - mask_ratio) + 1) x C.
- ids_restore (torch.Tensor): ids to restore original image.
- ids_keep (torch.Tensor): ids of the kept patches.
- ids_dump (torch.Tensor): ids of the removed patches.
"""
B = x.shape[0]
x = self.patch_embed(x)[0]
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, ids_restore, ids_keep, ids_dump = self.attention_masking(
x, self.mask_ratio, importance)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
for _, layer in enumerate(self.layers):
x = layer(x)
# Use final norm
x = self.norm1(x)
return x, ids_restore, ids_keep, ids_dump
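A hedged sketch (not part of the diff) of attention masking; the uniform random importance below is only a stand-in for the per-patch CLIP attention that CLIPGenerator would provide.

import torch

from mmpretrain.registry import MODELS

backbone = MODELS.build(dict(type='MILANViT', arch='b', mask_ratio=0.75))
imgs = torch.rand(2, 3, 224, 224)
# Stand-in for CLIP attention scores: one positive weight per patch.
importance = torch.rand(2, 196).softmax(dim=-1)
feats, ids_restore, ids_keep, ids_dump = backbone(imgs, importance)
print(feats.shape)  # (2, 50, 768): cls token + the 49 most 'important' patches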

View File

@ -0,0 +1,200 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Optional, Union
import torch
from torch import nn
from torch.nn import functional as F
from mmpretrain.models.backbones import MixMIMTransformer
from mmpretrain.registry import MODELS
from ..utils import build_2d_sincos_position_embedding
@MODELS.register_module()
class MixMIMPretrainTransformer(MixMIMTransformer):
"""MixMIM backbone for MixMIM pre-training.
A PyTorch implementation of: `MixMIM: Mixed and Masked Image
Modeling for Efficient Visual Representation Learning
<https://arxiv.org/abs/2205.13137>`_
Args:
arch (str | dict): MixMIM architecture. If using a string,
choose from 'base', 'large' and 'huge'.
If using a dict, it should have the below keys:
- **embed_dims** (int): The dimensions of embedding.
- **depths** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
Defaults to 'base'.
mlp_ratio (int): The mlp ratio in FFN. Defaults to 4.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to
the most common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
window_size (list): The height and width of the window.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to an empty dict.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
attn_drop_rate (float): Attention drop rate. Defaults to 0.
use_checkpoint (bool): Whether to use checkpointing to
reduce GPU memory cost. Defaults to False.
range_mask_ratio (float): The range of mask ratio.
Defaults to 0.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
arch: Union[str, dict] = 'base',
mlp_ratio: float = 4,
img_size: int = 224,
patch_size: int = 4,
in_channels: int = 3,
window_size: List = [14, 14, 14, 7],
qkv_bias: bool = True,
patch_cfg: dict = dict(),
norm_cfg: dict = dict(type='LN'),
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
attn_drop_rate: float = 0.0,
use_checkpoint: bool = False,
range_mask_ratio: float = 0.0,
init_cfg: Optional[dict] = None) -> None:
super().__init__(
arch=arch,
mlp_ratio=mlp_ratio,
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
window_size=window_size,
qkv_bias=qkv_bias,
patch_cfg=patch_cfg,
norm_cfg=norm_cfg,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
attn_drop_rate=attn_drop_rate,
use_checkpoint=use_checkpoint,
init_cfg=init_cfg)
self.range_mask_ratio = range_mask_ratio
def init_weights(self):
"""Initialize position embedding, patch embedding."""
super(MixMIMTransformer, self).init_weights()
pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.absolute_pos_embed.shape[-1],
cls_token=False)
self.absolute_pos_embed.data.copy_(pos_embed.float())
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def random_masking(self, x: torch.Tensor, mask_ratio: float = 0.5):
"""Generate the mask for MixMIM Pretraining.
Args:
x (torch.Tensor): Image with data augmentation applied, which is
of shape B x L x C.
mask_ratio (float): The mask ratio of total patches.
Defaults to 0.5.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- mask_s1 (torch.Tensor): mask with stride of
self.encoder_stride // 8.
- mask_s2 (torch.Tensor): mask with stride of
self.encoder_stride // 4.
- mask_s3 (torch.Tensor): mask with stride of
self.encoder_stride // 2.
- mask (torch.Tensor): mask with stride of
self.encoder_stride.
"""
B, C, H, W = x.shape
out_H = H // self.encoder_stride
out_W = W // self.encoder_stride
s3_H, s3_W = out_H * 2, out_W * 2
s2_H, s2_W = out_H * 4, out_W * 4
s1_H, s1_W = out_H * 8, out_W * 8
seq_l = out_H * out_W
# use a shared mask for a batch images
mask = torch.zeros([1, 1, seq_l], device=x.device)
mask_ratio = mask_ratio + random.uniform(0.0, self.range_mask_ratio)
noise = torch.rand(1, 1, seq_l, device=x.device) # noise in [0, 1]
# ascend: small is keep, large is removed
mask_idx = torch.argsort(noise, dim=2)[:, :, :int(seq_l * mask_ratio)]
mask.scatter_(2, mask_idx, 1)
mask = mask.reshape(1, 1, out_H, out_W)
mask_s1 = F.interpolate(mask, size=(s1_H, s1_W), mode='nearest')
mask_s2 = F.interpolate(mask, size=(s2_H, s2_W), mode='nearest')
mask_s3 = F.interpolate(mask, size=(s3_H, s3_W), mode='nearest')
mask = mask.reshape(1, out_H * out_W, 1).contiguous()
mask_s1 = mask_s1.reshape(1, s1_H * s1_W, 1).contiguous()
mask_s2 = mask_s2.reshape(1, s2_H * s2_W, 1).contiguous()
mask_s3 = mask_s3.reshape(1, s3_H * s3_W, 1).contiguous()
return mask_s1, mask_s2, mask_s3, mask
def forward(self, x: torch.Tensor, mask_ratio=0.5):
"""Generate features for masked images.
This function generates the masks, mixes the masked patches across the
batch, and gets the hidden features.
Args:
x (torch.Tensor): Input images, which is of shape B x C x H x W.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- x (torch.Tensor): hidden features, which is of shape
B x L x C.
- mask_s4 (torch.Tensor): the mask tensor for the last layer.
"""
mask_s1, mask_s2, mask_s3, mask_s4 = self.random_masking(x, mask_ratio)
x, _ = self.patch_embed(x)
x = x * (1. - mask_s1) + x.flip(0) * mask_s1
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)
for idx, layer in enumerate(self.layers):
if idx == 0:
x = layer(x, attn_mask=mask_s1)
elif idx == 1:
x = layer(x, attn_mask=mask_s2)
elif idx == 2:
x = layer(x, attn_mask=mask_s3)
elif idx == 3:
x = layer(x, attn_mask=mask_s4)
x = self.norm(x)
return x, mask_s4
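A hedged sketch (not part of the diff) of the mask pyramid; it assumes the parent MixMIMTransformer defines encoder_stride = 32, so a 224x224 input gives a 7x7 final-stage mask that is upsampled to 14x14, 28x28 and 56x56 for the earlier stages, and it assumes the 'base' arch ends with 1024-dimensional features.

import torch

from mmpretrain.registry import MODELS

backbone = MODELS.build(dict(type='MixMIMPretrainTransformer', arch='base'))
imgs = torch.rand(2, 3, 224, 224)
feats, mask_s4 = backbone(imgs, mask_ratio=0.5)
print(feats.shape)    # (2, 49, 1024) under the assumptions above
print(mask_s4.shape)  # (1, 49, 1): a single mask shared across the batch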

View File

@ -0,0 +1,134 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from functools import reduce
from operator import mul
from typing import List, Optional, Union
import torch.nn as nn
from mmcv.cnn.bricks.transformer import PatchEmbed
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.backbones import VisionTransformer
from mmpretrain.models.utils import (build_2d_sincos_position_embedding,
to_2tuple)
from mmpretrain.registry import MODELS
@MODELS.register_module()
class MoCoV3ViT(VisionTransformer):
"""Vision Transformer for MoCoV3 pre-training.
A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers for
Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.
Part of the code is modified from:
`<https://github.com/facebookresearch/moco-v3/blob/main/vits.py>`_.
Args:
stop_grad_conv1 (bool): whether to stop the gradient of
convolution layer in `PatchEmbed`. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
stop_grad_conv1: bool = False,
frozen_stages: int = -1,
norm_eval: bool = False,
init_cfg: Optional[Union[dict, List[dict]]] = None,
**kwargs) -> None:
# add MoCoV3 ViT-small arch
self.arch_zoo.update(
dict.fromkeys(
['mocov3-s', 'mocov3-small'], {
'embed_dims': 384,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 1536,
}))
super().__init__(init_cfg=init_cfg, **kwargs)
self.patch_size = kwargs['patch_size']
self.frozen_stages = frozen_stages
self.norm_eval = norm_eval
self.init_cfg = init_cfg
if isinstance(self.patch_embed, PatchEmbed):
if stop_grad_conv1:
self.patch_embed.projection.weight.requires_grad = False
self.patch_embed.projection.bias.requires_grad = False
self._freeze_stages()
def init_weights(self) -> None:
"""Initialize position embedding, patch embedding, qkv layers and cls
token."""
super().init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Use fixed 2D sin-cos position embedding
pos_emb = build_2d_sincos_position_embedding(
patches_resolution=self.patch_resolution,
embed_dims=self.embed_dims,
cls_token=True)
self.pos_embed.data.copy_(pos_emb)
self.pos_embed.requires_grad = False
# xavier_uniform initialization for PatchEmbed
if isinstance(self.patch_embed, PatchEmbed):
val = math.sqrt(
6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) +
self.embed_dims))
nn.init.uniform_(self.patch_embed.projection.weight, -val, val)
nn.init.zeros_(self.patch_embed.projection.bias)
# initialization for linear layers
for name, m in self.named_modules():
if isinstance(m, nn.Linear):
if 'qkv' in name:
# treat the weights of Q, K, V separately
val = math.sqrt(
6. /
float(m.weight.shape[0] // 3 + m.weight.shape[1]))
nn.init.uniform_(m.weight, -val, val)
else:
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)
nn.init.normal_(self.cls_token, std=1e-6)
def _freeze_stages(self) -> None:
"""Freeze patch_embed layer, some parameters and stages."""
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
self.cls_token.requires_grad = False
self.pos_embed.requires_grad = False
for i in range(1, self.frozen_stages + 1):
m = self.layers[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
if i == (self.num_layers) and self.final_norm:
for param in getattr(self, 'norm1').parameters():
param.requires_grad = False
def train(self, mode: bool = True) -> None:
super().train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval have effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
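A hedged usage sketch (not part of the diff): build the MoCoV3 ViT-small backbone registered above with the patch-embedding projection frozen, and check that its weights stop receiving gradients.

import torch
from mmpretrain.models import MoCoV3ViT

vit = MoCoV3ViT(
    arch='mocov3-small',
    patch_size=16,
    stop_grad_conv1=True,  # freeze the PatchEmbed projection
    frozen_stages=-1,      # keep all transformer stages trainable
    norm_eval=False)
vit.init_weights()

feats = vit(torch.randn(2, 3, 224, 224))
print(len(feats), vit.patch_embed.projection.weight.requires_grad)
# 1 False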

View File

@ -0,0 +1,154 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.models import SwinTransformer
from mmpretrain.registry import MODELS
@MODELS.register_module()
class SimMIMSwinTransformer(SwinTransformer):
"""Swin Transformer for SimMIM pre-training.
    Args:
        arch (str | dict): Swin Transformer architecture.
            Defaults to 'T'.
img_size (int | tuple): The size of input image.
Defaults to 224.
in_channels (int): The num of input channels.
Defaults to 3.
drop_rate (float): Dropout rate after embedding.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate.
Defaults to 0.1.
out_indices (tuple): Layers to be outputted. Defaults to (3, ).
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
with_cp (bool): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
norm_cfg (dict): Config dict for normalization layer at end
            of backbone. Defaults to ``dict(type='LN')``.
stage_cfgs (Sequence | dict): Extra config dict for each
stage. Defaults to empty dict.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to empty dict.
pad_small_map (bool): If True, pad the small feature map to the window
            size, which is commonly used in detection and segmentation. If
            False, avoid shifting window and shrink the window size to the
            size of the feature map, which is commonly used in classification.
Defaults to False.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
arch: Union[str, dict] = 'T',
img_size: Union[Tuple[int, int], int] = 224,
in_channels: int = 3,
drop_rate: float = 0.,
drop_path_rate: float = 0.1,
out_indices: tuple = (3, ),
use_abs_pos_embed: bool = False,
with_cp: bool = False,
                 frozen_stages: int = -1,
norm_eval: bool = False,
norm_cfg: dict = dict(type='LN'),
stage_cfgs: Union[Sequence, dict] = dict(),
patch_cfg: dict = dict(),
pad_small_map: bool = False,
init_cfg: Optional[dict] = None) -> None:
super().__init__(
arch=arch,
img_size=img_size,
in_channels=in_channels,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
out_indices=out_indices,
use_abs_pos_embed=use_abs_pos_embed,
with_cp=with_cp,
frozen_stages=frozen_stages,
norm_eval=norm_eval,
norm_cfg=norm_cfg,
stage_cfgs=stage_cfgs,
patch_cfg=patch_cfg,
pad_small_map=pad_small_map,
init_cfg=init_cfg)
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
def init_weights(self) -> None:
"""Initialize weights."""
super().init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)
trunc_normal_(self.mask_token, mean=0, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
"""Initialize weights."""
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor,
mask: torch.Tensor) -> Sequence[torch.Tensor]:
"""Generate features for masked images.
        This function generates masked images and gets the hidden features
        for them.
Args:
x (torch.Tensor): Input images.
mask (torch.Tensor): Masks used to construct masked images.
Returns:
tuple: A tuple containing features from multi-stages.
"""
x, hw_shape = self.patch_embed(x)
assert mask is not None
B, L, _ = x.shape
mask_token = self.mask_token.expand(B, L, -1)
w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
x = x * (1. - w) + mask_token * w
if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape = stage(x, hw_shape)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape,
stage.out_channels).permute(0, 3, 1,
2).contiguous()
outs.append(out)
return tuple(outs)
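A hedged usage sketch (not part of the diff): run the SimMIM Swin backbone above on a 192x192 image with a patch-level mask; the 48x48 mask grid and window_size=6 follow the unit test added later in this commit.

import torch
from mmpretrain.models import SimMIMSwinTransformer

backbone = SimMIMSwinTransformer(
    arch='B',
    img_size=192,
    stage_cfgs=dict(block_cfgs=dict(window_size=6)))
backbone.init_weights()

imgs = torch.randn(2, 3, 192, 192)
mask = (torch.rand(2, 48, 48) > 0.4).float()  # 1 marks a masked patch
feats = backbone(imgs, mask)
print(feats[0].shape)  # expected: torch.Size([2, 1024, 6, 6])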

View File

@ -5,6 +5,7 @@ from .attention import (BEiTAttention, ChannelMultiheadAttention,
ShiftWindowMSA, WindowMSA, WindowMSAV2)
from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix
from .channel_shuffle import channel_shuffle
from .clip_generator_helper import build_clip_model
from .data_preprocessor import ClsDataPreprocessor
from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed,
resize_relative_position_bias_table)
@ -17,6 +18,7 @@ from .position_encoding import (ConditionalPositionEncoding,
PositionEncodingFourier,
build_2d_sincos_position_embedding)
from .se_layer import SELayer
from .vector_quantizer import NormEMAVectorQuantizer
__all__ = [
'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
@ -28,5 +30,6 @@ __all__ = [
'LayerScale', 'WindowMSA', 'WindowMSAV2', 'ChannelMultiheadAttention',
'PositionEncodingFourier', 'LeAttention', 'GRN', 'LayerNorm2d',
'build_norm_layer', 'CrossMultiheadAttention',
'build_2d_sincos_position_embedding', 'PromptMultiheadAttention'
'build_2d_sincos_position_embedding', 'PromptMultiheadAttention',
'NormEMAVectorQuantizer', 'build_clip_model'
]

View File

@ -567,7 +567,7 @@ class BEiTAttention(BaseModule):
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (tuple[int]): The height and width of the window.
window_size (tuple[int, int]): The height and width of the window.
use_rel_pos_bias (bool): Whether to use unique relative position bias,
if False, use shared relative position bias defined in backbone.
        bias (str): The option to add learnable bias for q, k, v. If bias is
@ -606,6 +606,10 @@ class BEiTAttention(BaseModule):
self._init_qv_bias()
qkv_bias = False
if window_size is None:
assert not use_rel_pos_bias
else:
assert isinstance(window_size, tuple)
self.window_size = window_size
self.use_rel_pos_bias = use_rel_pos_bias
self._init_rel_pos_embedding()

View File

@ -0,0 +1,391 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/zejiangh/MILAN
from collections import OrderedDict
from typing import Optional, Tuple, Union
import numpy as np
import torch
from mmengine.logging import MMLogger
from torch import nn
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
"""A faster version of GELU."""
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
"""Residual Attention Block (RAB).
This module implements the same function as the MultiheadAttention in
MMClassification, but with a different interface, which is mainly used
in CLIP.
Args:
d_model (int): The feature dimension.
n_head (int): The number of attention heads.
attn_mask (torch.Tensor, optional): The attention mask.
            Defaults to None.
        return_attention (bool): Whether to also return the attention
            weights. Defaults to False.
    """
def __init__(self,
d_model: int,
n_head: int,
attn_mask: Optional[torch.Tensor] = None,
return_attention: bool = False) -> None:
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model))]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.return_attention = return_attention
def attention(self, x: torch.Tensor) -> torch.Tensor:
"""Attention function."""
self.attn_mask = self.attn_mask.to(
dtype=x.dtype,
device=x.device) if self.attn_mask is not None else None
if self.return_attention:
return self.attn(
x,
x,
x,
need_weights=self.return_attention,
attn_mask=self.attn_mask)
else:
return self.attn(
x,
x,
x,
need_weights=self.return_attention,
attn_mask=self.attn_mask)[0]
def forward(
self, x: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Forward function."""
if self.return_attention:
x_, attention = self.attention(self.ln_1(x))
x = x + x_
x = x + self.mlp(self.ln_2(x))
return x, attention
else:
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
"""Transformer.
Both visual and text branches use this transformer.
Args:
width (int): The feature dimension.
layers (int): The number of layers.
heads (int): The number of attention heads.
attn_mask (torch.Tensor, optional): The attention mask.
"""
def __init__(self,
width: int,
layers: int,
heads: int,
attn_mask: Optional[torch.Tensor] = None) -> None:
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList()
for _ in range(layers - 1):
self.resblocks.append(
ResidualAttentionBlock(width, heads, attn_mask))
self.resblocks.append(
ResidualAttentionBlock(
width, heads, attn_mask, return_attention=True))
def forward(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward function."""
z = []
for idx, blk in enumerate(self.resblocks):
if idx < self.layers - 1:
x = blk(x)
z.append(x.permute(1, 0, 2))
else:
x, attention = blk(x)
z.append(x.permute(1, 0, 2))
return x, attention, z
class VisionTransformer(nn.Module):
"""Vision Transformer for CLIP.
Args:
input_resolution (int): The image size.
patch_size (int): The patch size.
width (int): The feature dimension.
layers (int): The number of layers.
heads (int): The number of attention heads.
        output_dim (int): The output dimension.
        finetune (bool): Whether to finetune the model. Defaults to False.
        average_targets (int): The number of targets to average.
            Defaults to 1.
"""
def __init__(self,
input_resolution: int,
patch_size: int,
width: int,
layers: int,
heads: int,
output_dim: int,
finetune=False,
average_targets: int = 1) -> None:
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)
scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.finetune = finetune
if finetune is False:
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
self.average_targets = average_targets
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function."""
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([
self.class_embedding.to(x.dtype) + torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
],
dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x, attention, z = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x)
if self.proj is not None:
x = x @ self.proj
return x, attention
class CLIP(nn.Module):
"""CLIP.
Args:
embed_dim (int): The embedding dimension.
image_resolution (int): The image size.
vision_layers (int): The number of layers in the vision transformer.
vision_width (int): The feature dimension in the vision transformer.
vision_patch_size (int): The patch size in the vision transformer.
context_length (int): The context length.
vocab_size (int): The vocabulary size.
transformer_width (int): The feature dimension in the text transformer.
transformer_heads (int): The number of attention heads in the
text transformer.
transformer_layers (int): The number of layers in the text transformer.
        finetune (bool): Whether to finetune the model. Defaults to False.
        average_targets (int): The number of targets to average.
            Defaults to 1.
"""
def __init__(
self,
embed_dim: int,
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
finetune: bool = False,
average_targets: int = 1,
) -> None:
super().__init__()
self.context_length = context_length
vision_heads = vision_width // 64
self.visual = VisionTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim,
finetune=finetune,
average_targets=average_targets,
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask())
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(
torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.initialize_parameters()
def initialize_parameters(self) -> None:
"""Initialize the parameters.
The pretrained weight will override the initialized parameters by this
function.
"""
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers)**-0.5)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width)**-0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(
self.text_projection, std=self.transformer.width**-0.5)
def build_attention_mask(self) -> torch.Tensor:
"""Build the attention mask."""
        # lazily create causal attention mask, with full attention between
        # the vision tokens. PyTorch uses an additive attention mask; fill
        # with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float('-inf'))
mask.triu_(1) # zero out the lower diagonal
return mask
@property
def dtype(self) -> torch.dtype:
"""Get the dtype."""
return self.visual.conv1.weight.dtype
def encode_image(self,
image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode the image.
Get the feature and attention mask from the last layer of the visual
branch of CLIP.
Args:
image (torch.Tensor): The image tensor with shape NCHW.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The feature and attention mask.
"""
return self.visual(image.type(self.dtype))
def build_clip_model(state_dict: dict,
finetune: bool = False,
average_targets: int = 1) -> nn.Module:
"""Build the CLIP model.
Args:
state_dict (dict): The pretrained state dict.
        finetune (bool): Whether to finetune the model. Defaults to False.
        average_targets (int): The number of targets to average.
            Defaults to 1.
Returns:
nn.Module: The CLIP model.
"""
vit = 'visual.proj' in state_dict
if vit:
vision_width = state_dict['visual.conv1.weight'].shape[0]
vision_layers = len([
k for k in state_dict.keys()
if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
])
vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
grid_size = round(
(state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
image_resolution = vision_patch_size * grid_size
embed_dim = state_dict['text_projection'].shape[1]
context_length = state_dict['positional_embedding'].shape[0]
vocab_size = state_dict['token_embedding.weight'].shape[0]
transformer_width = state_dict['ln_final.weight'].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(
set(
k.split('.')[2] for k in state_dict
if k.startswith('transformer.resblocks')))
model = CLIP(
embed_dim,
image_resolution,
vision_layers,
vision_width,
vision_patch_size,
context_length,
vocab_size,
transformer_width,
transformer_heads,
transformer_layers,
finetune,
average_targets,
)
for key in ['input_resolution', 'context_length', 'vocab_size']:
if key in state_dict:
del state_dict[key]
msg = model.load_state_dict(state_dict, strict=False)
MMLogger.get_current_instance().info(f'Load CLIP model: {msg}')
return model.eval()
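A hedged usage sketch (not part of the diff): build a frozen CLIP target generator from a pre-downloaded checkpoint. The checkpoint path is a placeholder; any state dict with the original CLIP key layout parsed above should work.

import torch
from mmpretrain.models.utils import build_clip_model

state_dict = torch.load('path/to/clip_vit_base_16.pth', map_location='cpu')
clip = build_clip_model(state_dict, finetune=False, average_targets=1)

with torch.no_grad():
    feats, attention = clip.encode_image(torch.randn(2, 3, 224, 224))
print(feats.shape, attention.shape)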

View File

@ -81,7 +81,7 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
else:
self._enable_normalize = False
if batch_augments is not None:
if batch_augments:
self.batch_augments = RandomBatchAugment(**batch_augments)
if not self.to_onehot:
from mmengine.logging import MMLogger

View File

@ -0,0 +1,232 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) 2022 Microsoft
# Modified from
# https://github.com/microsoft/unilm/blob/master/beit2/norm_ema_quantizer.py
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from mmengine.dist import all_reduce
def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor,
decay: torch.Tensor) -> None:
"""Update moving average."""
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
def norm_ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor,
decay: torch.Tensor) -> None:
"""Update moving average with norm data."""
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
moving_avg.data.copy_(F.normalize(moving_avg.data, p=2, dim=-1))
def sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor:
"""Sample vectors according to the given number."""
num_samples, device = samples.shape[0], samples.device
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num, ), device=device)
return samples[indices]
def kmeans(samples: torch.Tensor,
num_clusters: int,
num_iters: int = 10,
use_cosine_sim: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
"""Run k-means algorithm."""
dim, dtype, _ = samples.shape[-1], samples.dtype, samples.device
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
if use_cosine_sim:
dists = samples @ means.t()
else:
diffs = rearrange(samples, 'n d -> n () d') \
- rearrange(means, 'c d -> () c d')
dists = -(diffs**2).sum(dim=-1)
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
bins_min_clamped = bins.masked_fill(zero_mask, 1)
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
if use_cosine_sim:
new_means = F.normalize(new_means, p=2, dim=-1)
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
class EmbeddingEMA(nn.Module):
"""The codebook of embedding vectors.
Args:
num_tokens (int): Number of embedding vectors in the codebook.
codebook_dim (int) : The dimension of embedding vectors in the
codebook.
kmeans_init (bool): Whether to use k-means to initialize the
VectorQuantizer. Defaults to True.
codebook_init_path (str): The initialization checkpoint for codebook.
Defaults to None.
"""
def __init__(self,
num_tokens: int,
codebook_dim: int,
kmeans_init: bool = True,
codebook_init_path: Optional[str] = None):
super().__init__()
self.num_tokens = num_tokens
self.codebook_dim = codebook_dim
if codebook_init_path is None:
if not kmeans_init:
weight = torch.randn(num_tokens, codebook_dim)
weight = F.normalize(weight, p=2, dim=-1)
else:
weight = torch.zeros(num_tokens, codebook_dim)
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
else:
print(f'load init codebook weight from {codebook_init_path}')
codebook_ckpt_weight = torch.load(
codebook_init_path, map_location='cpu')
weight = codebook_ckpt_weight.clone()
self.register_buffer('initted', torch.Tensor([True]))
self.weight = nn.Parameter(weight, requires_grad=False)
self.update = True
@torch.jit.ignore
def init_embed_(self, data: torch.Tensor) -> None:
"""Initialize embedding vectors of codebook."""
if self.initted:
return
print('Performing K-means init for codebook')
embed, _ = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
self.weight.data.copy_(embed)
self.initted.data.copy_(torch.Tensor([True]))
def forward(self, embed_id: torch.Tensor) -> torch.Tensor:
"""Get embedding vectors."""
return F.embedding(embed_id, self.weight)
class NormEMAVectorQuantizer(nn.Module):
"""Normed EMA vector quantizer module.
Args:
num_embed (int): Number of embedding vectors in the codebook. Defaults
to 8192.
embed_dims (int) : The dimension of embedding vectors in the codebook.
Defaults to 32.
        beta (float): The multiplier for VectorQuantizer embedding loss.
Defaults to 1.
decay (float): The decay parameter of EMA. Defaults to 0.99.
statistic_code_usage (bool): Whether to use cluster_size to record
statistic. Defaults to True.
kmeans_init (bool): Whether to use k-means to initialize the
VectorQuantizer. Defaults to True.
codebook_init_path (str): The initialization checkpoint for codebook.
Defaults to None.
"""
def __init__(self,
num_embed: int,
embed_dims: int,
beta: float,
decay: float = 0.99,
statistic_code_usage: bool = True,
kmeans_init: bool = True,
codebook_init_path: Optional[str] = None) -> None:
super().__init__()
self.codebook_dim = embed_dims
self.num_tokens = num_embed
self.beta = beta
self.decay = decay
# learnable = True if orthogonal_reg_weight > 0 else False
self.embedding = EmbeddingEMA(
num_tokens=self.num_tokens,
codebook_dim=self.codebook_dim,
kmeans_init=kmeans_init,
codebook_init_path=codebook_init_path)
self.statistic_code_usage = statistic_code_usage
if statistic_code_usage:
self.register_buffer('cluster_size', torch.zeros(num_embed))
    def reset_cluster_size(self, device):
        """Reset the statistic of cluster size."""
if self.statistic_code_usage:
self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
self.cluster_size = self.cluster_size.to(device)
def forward(self, z):
"""Forward function."""
# reshape z -> (batch, height, width, channel)
z = rearrange(z, 'b c h w -> b h w c')
z = F.normalize(z, p=2, dim=-1)
z_flattened = z.reshape(-1, self.codebook_dim)
self.embedding.init_embed_(z_flattened)
# 'n d -> d n'
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)
encoding_indices = torch.argmin(d, dim=1)
z_q = self.embedding(encoding_indices).view(z.shape)
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
if not self.training:
with torch.no_grad():
cluster_size = encodings.sum(0)
all_reduce(cluster_size)
ema_inplace(self.cluster_size, cluster_size, self.decay)
if self.training and self.embedding.update:
# update cluster size with EMA
bins = encodings.sum(0)
all_reduce(bins)
ema_inplace(self.cluster_size, bins, self.decay)
zero_mask = (bins == 0)
bins = bins.masked_fill(zero_mask, 1.)
embed_sum = z_flattened.t() @ encodings
all_reduce(embed_sum)
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
embed_normalized = F.normalize(embed_normalized, p=2, dim=-1)
embed_normalized = torch.where(zero_mask[..., None],
self.embedding.weight,
embed_normalized)
# Update embedding vectors with EMA
norm_ema_inplace(self.embedding.weight, embed_normalized,
self.decay)
# compute loss for embedding
loss = self.beta * F.mse_loss(z_q.detach(), z)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, 'b h w c -> b c h w')
return z_q, loss, encoding_indices
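A hedged usage sketch (not part of the diff): quantize a small feature map with the EMA codebook above. kmeans_init is disabled so the randomly initialized codebook can be used directly, and eval mode skips the EMA codebook update; all sizes are illustrative.

import torch
from mmpretrain.models.utils import NormEMAVectorQuantizer

quantizer = NormEMAVectorQuantizer(
    num_embed=1024, embed_dims=32, beta=1.0, kmeans_init=False)
quantizer.eval()  # skip the EMA update branch in this sketch

z = torch.randn(2, 32, 14, 14)  # (B, C, H, W) features from an encoder
z_q, loss, indices = quantizer(z)
print(z_q.shape, loss.item(), indices.shape)
# torch.Size([2, 32, 14, 14]) ... torch.Size([392])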

View File

@ -4,7 +4,7 @@ from unittest import TestCase
import torch
from mmpretrain.models.backbones import BEiT
from mmpretrain.models.backbones import BEiTViT
class TestBEiT(TestCase):
@ -18,7 +18,7 @@ class TestBEiT(TestCase):
with self.assertRaisesRegex(AssertionError, 'not in default archs'):
cfg = deepcopy(self.cfg)
cfg['arch'] = 'unknown'
BEiT(**cfg)
BEiTViT(**cfg)
# Test invalid custom arch
with self.assertRaisesRegex(AssertionError, 'Custom arch needs'):
@ -28,7 +28,7 @@ class TestBEiT(TestCase):
'num_heads': 16,
'feedforward_channels': 4096
}
BEiT(**cfg)
BEiTViT(**cfg)
# Test custom arch
cfg = deepcopy(self.cfg)
@ -38,7 +38,7 @@ class TestBEiT(TestCase):
'num_heads': 16,
'feedforward_channels': 1024
}
model = BEiT(**cfg)
model = BEiTViT(**cfg)
self.assertEqual(model.embed_dims, 128)
self.assertEqual(model.num_layers, 24)
self.assertIsNone(model.pos_embed)
@ -51,21 +51,21 @@ class TestBEiT(TestCase):
cfg = deepcopy(self.cfg)
cfg['out_indices'] = {1: 1}
with self.assertRaisesRegex(AssertionError, "get <class 'dict'>"):
BEiT(**cfg)
BEiTViT(**cfg)
cfg['out_indices'] = [0, 13]
with self.assertRaisesRegex(AssertionError, 'Invalid out_indices 13'):
BEiT(**cfg)
BEiTViT(**cfg)
# Test pos_embed
cfg = deepcopy(self.cfg)
cfg['use_abs_pos_emb'] = True
model = BEiT(**cfg)
model = BEiTViT(**cfg)
self.assertEqual(model.pos_embed.shape, (1, 197, 768))
# Test model structure
cfg = deepcopy(self.cfg)
cfg['drop_path_rate'] = 0.1
model = BEiT(**cfg)
model = BEiTViT(**cfg)
self.assertEqual(len(model.layers), 12)
dpr_inc = 0.1 / (12 - 1)
dpr = 0
@ -85,7 +85,7 @@ class TestBEiT(TestCase):
# test with output_cls_token
cfg = deepcopy(self.cfg)
cfg['output_cls_token'] = True
model = BEiT(**cfg)
model = BEiTViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
@ -95,7 +95,7 @@ class TestBEiT(TestCase):
# test without output_cls_token
cfg = deepcopy(self.cfg)
model = BEiT(**cfg)
model = BEiTViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
@ -105,7 +105,7 @@ class TestBEiT(TestCase):
# test without average
cfg = deepcopy(self.cfg)
cfg['avg_token'] = False
model = BEiT(**cfg)
model = BEiTViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 1)
@ -115,7 +115,7 @@ class TestBEiT(TestCase):
# Test forward with multi out indices
cfg = deepcopy(self.cfg)
cfg['out_indices'] = [-3, -2, -1]
model = BEiT(**cfg)
model = BEiTViT(**cfg)
outs = model(imgs)
self.assertIsInstance(outs, tuple)
self.assertEqual(len(outs), 3)

View File

@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmpretrain.models import BEiTPretrainViT
backbone = dict(
arch='base',
patch_size=16,
drop_path_rate=0.1,
final_norm=True,
layer_scale_init_value=0.1,
)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_beit_pretrain_vit():
beit_backbone = BEiTPretrainViT(**backbone)
beit_backbone.init_weights()
fake_inputs = torch.randn((2, 3, 224, 224))
fake_mask = torch.zeros((2, 196))
fake_mask[:, 75:150] = 1
fake_outputs = beit_backbone(fake_inputs, fake_mask)
assert list(fake_outputs[0].shape) == [2, 197, 768]

View File

@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmpretrain.models import CAEViT
backbone = dict(arch='b', patch_size=16, layer_scale_init_value=0.1)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_cae_vit():
cae_backbone = CAEViT(**backbone)
cae_backbone.init_weights()
fake_inputs = torch.randn((2, 3, 224, 224))
fake_mask = torch.zeros((2, 196)).bool()
fake_mask[:, 75:150] = 1
fake_outputs = cae_backbone(fake_inputs, fake_mask)
assert list(fake_outputs.shape) == [2, 122, 768]

View File

@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmpretrain.models import MAEViT
backbone = dict(arch='b', patch_size=16, mask_ratio=0.75)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_mae_vit():
mae_backbone = MAEViT(**backbone)
mae_backbone.init_weights()
fake_inputs = torch.randn((2, 3, 224, 224))
fake_outputs = mae_backbone(fake_inputs)[0]
assert list(fake_outputs.shape) == [2, 50, 768]

View File

@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmpretrain.models import MaskFeatViT
backbone = dict(arch='b', patch_size=16)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_maskfeat_vit():
maskfeat_backbone = MaskFeatViT(**backbone)
maskfeat_backbone.init_weights()
fake_inputs = torch.randn((2, 3, 224, 224))
fake_mask = torch.randn((2, 14, 14))
fake_outputs = maskfeat_backbone(fake_inputs, fake_mask)
assert list(fake_outputs.shape) == [2, 197, 768]

View File

@ -0,0 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
from mmpretrain.models import MoCoV3ViT
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_vision_transformer():
vit = MoCoV3ViT(
arch='mocov3-small', patch_size=16, frozen_stages=12, norm_eval=True)
vit.init_weights()
vit.train()
for p in vit.parameters():
assert p.requires_grad is False

View File

@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmpretrain.models import SimMIMSwinTransformer
backbone = dict(
arch='B', img_size=192, stage_cfgs=dict(block_cfgs=dict(window_size=6)))
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_simmim_swin_transformer():
simmim_backbone = SimMIMSwinTransformer(**backbone)
simmim_backbone.init_weights()
fake_inputs = torch.randn((2, 3, 192, 192))
fake_mask = torch.rand((2, 48, 48))
fake_outputs = simmim_backbone(fake_inputs, fake_mask)[0]
assert list(fake_outputs.shape) == [2, 1024, 6, 6]

View File

@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
from unittest import TestCase
import pytest
import torch
from mmpretrain.models import VQKD, Encoder, HOGGenerator
class TestDALLE(TestCase):
@pytest.mark.skipif(
platform.system() == 'Windows', reason='Windows mem limit')
def test_dalle(self):
model = Encoder()
fake_inputs = torch.rand((2, 3, 112, 112))
fake_outputs = model(fake_inputs)
assert list(fake_outputs.shape) == [2, 8192, 14, 14]
class TestHOGGenerator(TestCase):
def test_hog_generator(self):
hog_generator = HOGGenerator()
fake_input = torch.randn((2, 3, 224, 224))
fake_output = hog_generator(fake_input)
assert list(fake_output.shape) == [2, 196, 108]
fake_hog_out = hog_generator.out[0].unsqueeze(0)
fake_hog_img = hog_generator.generate_hog_image(fake_hog_out)
assert fake_hog_img.shape == (224, 224)
with pytest.raises(AssertionError):
fake_hog_img = hog_generator.generate_hog_image(
hog_generator.out[0])
class TestVQKD(TestCase):
ENCODER_CFG = dict(
arch='base',
img_size=224,
patch_size=16,
in_channels=3,
out_indices=-1,
drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
with_cls_token=True,
avg_token=False,
frozen_stages=-1,
output_cls_token=False,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
layer_scale_init_value=0.,
interpolate_mode='bicubic',
patch_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None)
@pytest.mark.skipif(
platform.system() == 'Windows', reason='Windows mem limit')
def test_vqkd(self):
model = VQKD(encoder_config=self.ENCODER_CFG)
fake_inputs = torch.rand((2, 3, 224, 224))
fake_outputs = model(fake_inputs)
assert list(fake_outputs.shape) == [2, 196]