# Copyright (c) OpenMMLab. All rights reserved.
# Part of code is modified from BEiT
# https://github.com/microsoft/unilm/blob/master/beit/dall_e/encoder.py
import math
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.models.backbones import BEiTViT
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
from .base import BaseSelfSupervisor


class Conv2d(nn.Module):
    """Rewrite Conv2d module according to DALL-E code."""

    def __init__(self,
                 n_in: int,
                 n_out: int,
                 kw: int,
                 use_float16: bool = True,
                 device: torch.device = torch.device('cpu'),
                 requires_grad: bool = False) -> None:
        super().__init__()

        w = torch.empty((n_out, n_in, kw, kw),
                        dtype=torch.float32,
                        device=device,
                        requires_grad=requires_grad)
        w.normal_(std=1 / math.sqrt(n_in * kw**2))

        b = torch.zeros((n_out, ),
                        dtype=torch.float32,
                        device=device,
                        requires_grad=requires_grad)
        self.kw = kw
        self.w, self.b = nn.Parameter(w), nn.Parameter(b)
        self.use_float16 = use_float16

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()

            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()

            w, b = self.w, self.b

        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)
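
# A minimal usage sketch of the DALL-E ``Conv2d`` above (illustrative only,
# not part of the original module). It pads with (kw - 1) // 2 so the spatial
# size is preserved, and it casts inputs/weights to float16 only when the
# weights live on a CUDA device and ``use_float16`` is True:
#
#     conv = Conv2d(n_in=3, n_out=64, kw=7)
#     out = conv(torch.randn(2, 3, 224, 224))   # -> (2, 64, 224, 224)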


class EncoderBlock(nn.Module):
    """Rewrite EncoderBlock module according to DALL-E code."""

    def __init__(self,
                 n_in: int,
                 n_out: int,
                 n_layers: int,
                 device: Optional[torch.device] = None,
                 requires_grad: bool = False) -> None:
        super().__init__()
        self.n_hid = n_out // 4
        self.post_gain = 1 / (n_layers**2)

        make_conv = partial(Conv2d, device=device, requires_grad=requires_grad)
        self.id_path = make_conv(n_in, n_out,
                                 1) if n_in != n_out else nn.Identity()
        self.res_path = nn.Sequential(
            OrderedDict([
                ('relu_1', nn.ReLU()),
                ('conv_1', make_conv(n_in, self.n_hid, 3)),
                ('relu_2', nn.ReLU()),
                ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
                ('relu_3', nn.ReLU()),
                ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
                ('relu_4', nn.ReLU()),
                ('conv_4', make_conv(self.n_hid, n_out, 1)),
            ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.id_path(x) + self.post_gain * self.res_path(x)
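
# Shape sketch for ``EncoderBlock`` (illustrative only): the residual path is
# a ReLU/conv bottleneck with ``n_hid = n_out // 4`` channels, and its output
# is scaled by ``post_gain = 1 / n_layers**2`` before being added to the
# identity path (a 1x1 conv whenever ``n_in != n_out``):
#
#     blk = EncoderBlock(n_in=256, n_out=512, n_layers=8)
#     out = blk(torch.randn(2, 256, 28, 28))    # -> (2, 512, 28, 28)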


@MODELS.register_module(name='DALL-E')
class DALLEEncoder(BaseModule):
    """DALL-E Encoder for feature extraction.

    Args:
        group_count (int): Number of groups in DALL-E encoder. Defaults to 4.
        n_hid (int): Dimension of hidden layers. Defaults to 256.
        n_blk_per_group (int): Number of blocks per group. Defaults to 2.
        input_channels (int): The channels of input images. Defaults to 3.
        vocab_size (int): Vocabulary size, indicating the number of classes.
            Defaults to 8192.
        device (torch.device): Device of parameters. Defaults to
            ``torch.device('cpu')``.
        requires_grad (bool): Require gradient or not. Defaults to False.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 group_count: int = 4,
                 n_hid: int = 256,
                 n_blk_per_group: int = 2,
                 input_channels: int = 3,
                 vocab_size: int = 8192,
                 device: torch.device = torch.device('cpu'),
                 requires_grad: bool = False,
                 init_cfg: Union[dict, List[dict], None] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.input_channels = input_channels

        blk_range = range(n_blk_per_group)
        n_layers = group_count * n_blk_per_group
        make_conv = partial(Conv2d, device=device, requires_grad=requires_grad)
        make_blk = partial(
            EncoderBlock,
            n_layers=n_layers,
            device=device,
            requires_grad=requires_grad)

        self.blocks = nn.Sequential(
            OrderedDict([
                ('input', make_conv(input_channels, 1 * n_hid, 7)),
                ('group_1',
                 nn.Sequential(
                     OrderedDict([
                         *[(f'block_{i + 1}', make_blk(1 * n_hid, 1 * n_hid))
                           for i in blk_range],
                         ('pool', nn.MaxPool2d(kernel_size=2)),
                     ]))),
                ('group_2',
                 nn.Sequential(
                     OrderedDict([
                         *[(f'block_{i + 1}',
                            make_blk(1 * n_hid if i == 0 else 2 * n_hid,
                                     2 * n_hid)) for i in blk_range],
                         ('pool', nn.MaxPool2d(kernel_size=2)),
                     ]))),
                ('group_3',
                 nn.Sequential(
                     OrderedDict([
                         *[(f'block_{i + 1}',
                            make_blk(2 * n_hid if i == 0 else 4 * n_hid,
                                     4 * n_hid)) for i in blk_range],
                         ('pool', nn.MaxPool2d(kernel_size=2)),
                     ]))),
                ('group_4',
                 nn.Sequential(
                     OrderedDict([
                         *[(f'block_{i + 1}',
                            make_blk(4 * n_hid if i == 0 else 8 * n_hid,
                                     8 * n_hid)) for i in blk_range],
                     ]))),
                ('output',
                 nn.Sequential(
                     OrderedDict([
                         ('relu', nn.ReLU()),
                         ('conv',
                          make_conv(
                              8 * n_hid, vocab_size, 1, use_float16=False)),
                     ]))),
            ]))
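
    # Channel/resolution flow through ``self.blocks`` (descriptive note): the
    # 7x7 input conv keeps the spatial size, each of group_1..group_3 halves
    # it with MaxPool2d(2), the channel width doubles per group
    # (n_hid -> 2*n_hid -> 4*n_hid -> 8*n_hid), and the final 1x1 conv maps
    # to ``vocab_size`` logits per spatial position.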

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function of DALL-E encoder.

        Args:
            x (torch.Tensor): The input images with shape (B, C, H, W).

        Returns:
            torch.Tensor: The output with shape (B, vocab_size, h, w).
        """
        x = x.float()
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model '
                             f'built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)
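
# Build/usage sketch (illustrative): the encoder is registered under the name
# 'DALL-E', and only group_1..group_3 end with a MaxPool2d(2), so the overall
# spatial downsampling factor is 8:
#
#     tokenizer = MODELS.build(dict(type='DALL-E'))
#     logits = tokenizer(torch.randn(2, 3, 224, 224))
#     # logits.shape == (2, 8192, 28, 28); argmax over dim=1 gives token ids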


@MODELS.register_module()
class CAEPretrainViT(BEiTViT):
    """Vision Transformer for CAE pre-training.

    The implementation is based on BEiTViT.

    Args:
        arch (str | dict): Vision Transformer architecture. Defaults to 'b'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        bias (bool | str): The option to add learnable bias for q, k, v. If
            bias is True, it will add learnable bias. If bias is 'qv_bias',
            it will only add learnable bias for q, v. If bias is False, it
            will not add bias for q, k, v. Defaults to 'qv_bias'.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            It only works without input mask. Defaults to ``"raw"``.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        layer_scale_init_value (float, optional): The init value of gamma in
            BEiTTransformerEncoderLayer. Defaults to None.
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        arch: str = 'b',
        img_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        out_indices: int = -1,
        drop_rate: float = 0,
        drop_path_rate: float = 0,
        bias: Union[bool, str] = 'qv_bias',
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        final_norm: bool = True,
        out_type: str = 'raw',
        frozen_stages: int = -1,
        use_abs_pos_emb: bool = True,
        use_rel_pos_bias: bool = False,
        use_shared_rel_pos_bias: bool = False,
        layer_scale_init_value: Optional[float] = None,
        interpolate_mode: str = 'bicubic',
        patch_cfg: dict = dict(),
        layer_cfgs: dict = dict(),
        init_cfg: Union[dict, List[dict]] = [
            dict(type='Constant', val=1, layer=['LayerNorm']),
            dict(type='TruncNormal', std=0.02, layer=['Conv2d']),
            dict(type='Xavier', distribution='uniform', layer=['Linear'])
        ]
    ) -> None:
        super().__init__(
            arch=arch,
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            out_indices=out_indices,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            bias=bias,
            norm_cfg=norm_cfg,
            final_norm=final_norm,
            out_type=out_type,
            with_cls_token=True,
            frozen_stages=frozen_stages,
            use_abs_pos_emb=use_abs_pos_emb,
            use_rel_pos_bias=use_rel_pos_bias,
            use_shared_rel_pos_bias=use_shared_rel_pos_bias,
            layer_scale_init_value=layer_scale_init_value,
            interpolate_mode=interpolate_mode,
            patch_cfg=patch_cfg,
            layer_cfgs=layer_cfgs,
            init_cfg=init_cfg)
        self.pos_embed.requires_grad = False
        self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]

    def init_weights(self) -> None:
        """Initialize position embedding, patch embedding and cls token."""
        super().init_weights()
        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # initialize position embedding in backbone
            pos_embed = build_2d_sincos_position_embedding(
                int(self.num_patches**.5),
                self.pos_embed.shape[-1],
                cls_token=True)
            self.pos_embed.data.copy_(pos_embed.float())

            trunc_normal_(self.cls_token, std=.02)
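
    # Sketch of what ``init_weights`` produces (illustrative): for a 224x224
    # image with 16x16 patches, ``num_patches`` is 196, so the fixed sin-cos
    # table copied into ``pos_embed`` has shape (1, 197, embed_dims),
    # including the cls token slot:
    #
    #     vit = CAEPretrainViT(arch='b')
    #     vit.init_weights()
    #     # vit.pos_embed.shape == (1, 197, 768); requires_grad is False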

    def forward(self, x: torch.Tensor,
                mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Generate features for masked images.

        This function applies the mask to the input images and gets the
        hidden features of the visible patches.

        The function supports two kinds of forward behaviors. If the ``mask``
        is not ``None``, the forward function will be executed as masked
        image modeling pre-training; if the ``mask`` is ``None``, the forward
        function will call ``super().forward()``, which extracts features
        from images without a mask.

        Args:
            x (torch.Tensor): Input images, which is of shape B x C x H x W.
            mask (torch.Tensor, optional): Mask for input, which is of shape
                B x L.

        Returns:
            torch.Tensor: hidden features.
        """
        if mask is None:
            return super().forward(x)

        else:
            x, _ = self.patch_embed(x)
            batch_size, _, dim = x.size()

            cls_tokens = self.cls_token.expand(batch_size, -1, -1)

            # NOTE: unmasked embeddings
            x_unmasked = x[~mask].reshape(batch_size, -1, dim)
            x_unmasked = torch.cat((cls_tokens, x_unmasked), dim=1)

            pos_embed = self.pos_embed.expand(batch_size,
                                              self.num_patches + 1, dim)
            pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape(
                batch_size, -1, dim)
            pos_embed_unmasked = torch.cat(
                (pos_embed[:, :1], pos_embed_unmasked), dim=1)
            x_unmasked = x_unmasked + pos_embed_unmasked

            x_unmasked = self.drop_after_pos(x_unmasked)

            for i, layer in enumerate(self.layers):
                x_unmasked = layer(x=x_unmasked, rel_pos_bias=None)

                if i == len(self.layers) - 1 and self.final_norm:
                    x_unmasked = self.norm1(x_unmasked)

            return x_unmasked
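
# Shape sketch of the masked forward pass above (illustrative): with a
# 224x224 input, 16x16 patches (196 in total) and a boolean mask of shape
# (B, 196) in which True marks a masked patch, only the visible patches plus
# the cls token are fed through the transformer layers:
#
#     vit = CAEPretrainViT(arch='b')
#     imgs = torch.randn(2, 3, 224, 224)
#     mask = torch.zeros(2, 196, dtype=torch.bool)
#     mask[:, :98] = True                      # mask half of the patches
#     feats = vit(imgs, mask)                  # -> (2, 1 + 98, 768)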


@MODELS.register_module()
class CAE(BaseSelfSupervisor):
    """CAE.

    Implementation of `Context Autoencoder for Self-Supervised Representation
    Learning <https://arxiv.org/abs/2202.03026>`_.

    Args:
        backbone (dict): Config dict for module of backbone.
        neck (dict): Config dict for module of neck.
        head (dict): Config dict for module of head functions.
        target_generator (dict, optional): The target_generator module to
            generate targets for self-supervised learning optimization, such
            as HOG or features extracted from other modules (DALL-E, CLIP),
            etc. Defaults to None.
        base_momentum (float): The base momentum coefficient for the target
            network. Defaults to 0.0.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None or the type is not specified, it will use
            "SelfSupDataPreprocessor" as type.
            See :class:`SelfSupDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 target_generator: Optional[dict] = None,
                 base_momentum: float = 0.0,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            target_generator=target_generator,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        self.momentum = base_momentum
        self.teacher = MODELS.build(backbone)

    def init_weights(self) -> None:
        """Initialize weights."""
        super().init_weights()

        # init the weights of teacher with those of backbone
        for param_backbone, param_teacher in zip(self.backbone.parameters(),
                                                 self.teacher.parameters()):
            param_teacher.detach()
            param_teacher.data.copy_(param_backbone.data)
            param_teacher.requires_grad = False

    def momentum_update(self) -> None:
        """Momentum update of the teacher network."""
        for param_backbone, param_teacher in zip(self.backbone.parameters(),
                                                 self.teacher.parameters()):
            param_teacher.data = param_teacher.data * self.momentum + \
                param_backbone.data * (1. - self.momentum)

    def extract_feat(self, inputs: torch.Tensor):
        return self.backbone(inputs, mask=None)
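
    # ``momentum_update`` is a standard EMA step (illustrative numbers): with
    # momentum = 0.998, a teacher weight of 1.0 and a backbone weight of 2.0
    # become 1.0 * 0.998 + 2.0 * (1 - 0.998) = 1.002 after one update. With
    # the default base_momentum = 0.0 the teacher simply copies the backbone
    # at every iteration.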

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """The forward function in training.

        Args:
            inputs (List[torch.Tensor]): The input images.
            data_samples (List[DataSample]): All elements required
                during the forward function.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components.
        """
        mask = torch.stack([data_sample.mask for data_sample in data_samples])
        mask = mask.flatten(1).to(torch.bool)

        unmasked = self.backbone(inputs[0], mask)

        # get the latent prediction for the masked patches
        with torch.no_grad():
            # inputs[0] is the prediction image
            latent_target = self.teacher(inputs[0], ~mask)
            latent_target = latent_target[:, 1:, :]
            self.momentum_update()

        pos_embed = self.backbone.pos_embed.expand(inputs[0].shape[0], -1, -1)
        pos_embed_masked = pos_embed[:, 1:][mask].reshape(
            inputs[0].shape[0], -1, pos_embed.shape[-1])
        pos_embed_unmasked = pos_embed[:, 1:][~mask].reshape(
            inputs[0].shape[0], -1, pos_embed.shape[-1])

        # input the unmasked tokens and masked tokens to the decoder
        logits, latent_pred = self.neck(unmasked[:, 1:], pos_embed_masked,
                                        pos_embed_unmasked)

        logits = logits.view(-1, logits.shape[-1])
        # inputs[1] is the target image
        logits_target = self.target_generator(inputs[1])
        loss_main, loss_align = self.head.loss(logits, logits_target,
                                               latent_pred, latent_target,
                                               mask)
        losses = dict()

        losses['loss'] = loss_main + loss_align
        losses['main'] = loss_main
        losses['align'] = loss_align
        return losses
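

# How the pieces above are typically wired together (a hedged sketch, not a
# verbatim config): ``CAE`` consumes two tensors per sample, inputs[0] (the
# prediction image) for the student backbone and the momentum teacher, and
# inputs[1] (the target image) for the target generator such as the DALL-E
# encoder above. The neck/head type names below are assumptions made for
# illustration; check the actual CAE config files for the exact names and
# arguments.
#
#     model = dict(
#         type='CAE',
#         backbone=dict(type='CAEPretrainViT', arch='b', patch_size=16),
#         neck=dict(type='CAENeck'),    # assumed name: latent regressor + decoder
#         head=dict(type='CAEHead'),    # assumed name: token + alignment losses
#         target_generator=dict(type='DALL-E'),
#         base_momentum=0.0)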