154 lines
5.8 KiB
Python
154 lines
5.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import math
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv.cnn import build_norm_layer
|
|
from mmengine.model import BaseModule
|
|
|
|
from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer
|
|
from mmpretrain.registry import MODELS
|
|
|
|
|
|
@MODELS.register_module()
class BEiTV2Neck(BaseModule):
    """Neck for BEiTV2 Pre-training.

    This module constructs the decoder for the final prediction. It stacks
    ``num_layers`` BEiT transformer encoder layers (the "patch aggregation"
    layers) on top of an early backbone output and applies one shared norm
    layer to both the final backbone features and the aggregated features.

    Args:
        num_layers (int): Number of encoder layers of neck. Defaults to 2.
        early_layers (int): The layer index of the early output from the
            backbone. Defaults to 9.
        backbone_arch (str | dict): Vision Transformer architecture. Either
            a name from ``arch_zoo`` (case-insensitive) or a custom dict
            with the keys ``embed_dims``, ``num_heads``,
            ``feedforward_channels`` and ``depth`` (``num_layers`` is
            accepted in place of ``depth``). Defaults to 'base'.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The initialization value for the
            learnable scaling of attention and FFN. Defaults to 0.1.
        use_rel_pos_bias (bool): Whether to use unique relative position bias,
            if False, use shared relative position bias defined in backbone.
            Defaults to False.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to ``dict(type='TruncNormal', layer='Linear',
            std=0.02, bias=0)``.
    """

    # Both the short and the long alias map to the same settings dict.
    arch_zoo = {
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'depth': 12,
                'num_heads': 12,
                'feedforward_channels': 3072,
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'depth': 24,
                'num_heads': 16,
                'feedforward_channels': 4096,
            }),
    }

    def __init__(
        self,
        num_layers: int = 2,
        early_layers: int = 9,
        backbone_arch: Union[str, dict] = 'base',
        drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        layer_scale_init_value: float = 0.1,
        use_rel_pos_bias: bool = False,
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        init_cfg: Optional[Union[dict, List[dict]]] = dict(
            type='TruncNormal', layer='Linear', std=0.02, bias=0)
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        if isinstance(backbone_arch, str):
            backbone_arch = backbone_arch.lower()
            assert backbone_arch in set(self.arch_zoo), \
                (f'Arch {backbone_arch} is not in default archs '
                 f'{set(self.arch_zoo)}')
            self.arch_settings = self.arch_zoo[backbone_arch]
        else:
            # The backbone depth is read below under the key 'depth' (the
            # key used by ``arch_zoo``). 'num_layers' was the key that used
            # to be validated here, so it stays accepted as an alias.
            essential_keys = {
                'embed_dims', 'num_heads', 'feedforward_channels'
            }
            depth_keys = {'depth', 'num_layers'}
            assert (isinstance(backbone_arch, dict)
                    and essential_keys <= set(backbone_arch)
                    and depth_keys & set(backbone_arch)), (
                        f'Custom arch needs a dict with keys '
                        f'{essential_keys} and one of {depth_keys}')
            self.arch_settings = backbone_arch

        self.early_layers = early_layers
        embed_dims = self.arch_settings['embed_dims']
        # Prefer 'depth'; fall back to the 'num_layers' alias for custom
        # arch dicts.
        depth = self.arch_settings.get('depth',
                                       self.arch_settings.get('num_layers'))

        # stochastic depth decay rule: rates grow linearly from 0 to
        # ``drop_path_rate``; the schedule is long enough to be indexed by
        # every neck layer position (early_layers ... early_layers +
        # num_layers - 1).
        dpr = np.linspace(0, drop_path_rate,
                          max(depth, early_layers + num_layers))

        self.patch_aggregation = nn.ModuleList()
        for i in range(early_layers, early_layers + num_layers):
            _layer_cfg = dict(
                embed_dims=embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                # plain Python float instead of a NumPy scalar
                drop_path_rate=float(dpr[i]),
                norm_cfg=norm_cfg,
                layer_scale_init_value=layer_scale_init_value,
                window_size=None,
                use_rel_pos_bias=use_rel_pos_bias)
            self.patch_aggregation.append(
                BEiTTransformerEncoderLayer(**_layer_cfg))

        # NOTE(review): this runs at construction time; if ``init_weights``
        # is called afterwards with the default ``init_cfg`` it presumably
        # re-initializes the Linear weights rescaled here - confirm.
        self.rescale_patch_aggregation_init_weight()

        _, norm = build_norm_layer(norm_cfg, embed_dims)
        self.add_module('norm', norm)

    def rescale_patch_aggregation_init_weight(self) -> None:
        """Rescale the initialized weights.

        For the i-th patch aggregation layer, the weight of the attention
        output projection and of ``ffn.layers[1]`` is divided in place by
        ``sqrt(2 * layer_id)`` with ``layer_id = early_layers + i + 1``.
        """
        for offset, layer in enumerate(self.patch_aggregation):
            layer_id = self.early_layers + offset + 1
            scale = math.sqrt(2.0 * layer_id)
            layer.attn.proj.weight.data.div_(scale)
            layer.ffn.layers[1].weight.data.div_(scale)

    def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor,
                **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the latent prediction and final prediction.

        Args:
            inputs (Tuple[torch.Tensor]): Features of tokens.
                ``inputs[0]`` is the early-layer output and ``inputs[1]``
                is the final-layer output of the backbone. Both are indexed
                with the token axis at dim 1 and the cls_token at index 0.
            rel_pos_bias (torch.Tensor): Shared relative position bias table.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
              - ``x``: The final layer features from backbone, which are normed
                in ``BEiTV2Neck``, with the cls_token removed.
              - ``x_cls_pt``: The early state features from backbone, which
                consist of the final layer cls_token and the early state
                patch_tokens from backbone, sent through the PatchAggregation
                layers in the neck, normed, with the cls_token removed.
        """
        early_states, x = inputs[0], inputs[1]

        # Final-layer cls_token (kept as a length-1 slice on dim 1) followed
        # by the early-layer patch tokens.
        x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1)
        for layer in self.patch_aggregation:
            x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias)

        # shared norm
        x, x_cls_pt = self.norm(x), self.norm(x_cls_pt)

        # remove cls_token
        x = x[:, 1:]
        x_cls_pt = x_cls_pt[:, 1:]
        return x, x_cls_pt
|