[Refactor] BEiT refactor (#1705)

* [Refactor] BEiT refactor
* [Fix] Fix arch zoo
* [Fix] Fix arch zoo
* [Fix] Fix freeze stages
* [Fix] Fix freeze ln2
* [Fix] Fix freezing vit ln2

Branch: pull/1709/head
parent 78d0ddc852
commit 5c43d3ef42
@@ -7,11 +7,13 @@ import torch.nn as nn
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
 from mmengine.model import BaseModule, ModuleList
 from mmengine.model.weight_init import trunc_normal_
 
 from mmpretrain.registry import MODELS
 from ..utils import (BEiTAttention, build_norm_layer, resize_pos_embed,
                      resize_relative_position_bias_table, to_2tuple)
-from .vision_transformer import TransformerEncoderLayer, VisionTransformer
+from .base_backbone import BaseBackbone
+from .vision_transformer import TransformerEncoderLayer
 
 
 class RelativePositionBias(BaseModule):
@@ -212,7 +214,7 @@ class BEiTTransformerEncoderLayer(TransformerEncoderLayer):
 
 
 @MODELS.register_module()
-class BEiTViT(VisionTransformer):
+class BEiTViT(BaseBackbone):
     """Backbone for BEiT.
 
     A PyTorch implement of : `BEiT: BERT Pre-Training of Image Transformers
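With this change BEiTViT is no longer a VisionTransformer subclass but a standalone backbone registered under the same name. A minimal usage sketch, assuming an environment with mmpretrain (plus its torch/mmcv/mmengine dependencies) installed; the argument names follow the signature shown in this diff and everything else stays at its documented defaults:

```python
import torch
from mmpretrain.models import BEiTViT  # registered via @MODELS.register_module()

# Build the refactored backbone with the 'base' preset from arch_zoo.
backbone = BEiTViT(arch='base')
backbone.init_weights()
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))

# The backbone returns a tuple with one tensor per requested out_index;
# with the default out_type='avg_featmap' each entry has shape (B, embed_dims).
print(len(feats), feats[-1].shape)
```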
@@ -282,6 +284,62 @@ class BEiTViT(VisionTransformer):
         init_cfg (dict, optional): Initialization config dict.
             Defaults to None.
     """
+    arch_zoo = {
+        **dict.fromkeys(
+            ['s', 'small'], {
+                'embed_dims': 768,
+                'num_layers': 8,
+                'num_heads': 8,
+                'feedforward_channels': 768 * 3,
+            }),
+        **dict.fromkeys(
+            ['b', 'base'], {
+                'embed_dims': 768,
+                'num_layers': 12,
+                'num_heads': 12,
+                'feedforward_channels': 3072
+            }),
+        **dict.fromkeys(
+            ['l', 'large'], {
+                'embed_dims': 1024,
+                'num_layers': 24,
+                'num_heads': 16,
+                'feedforward_channels': 4096
+            }),
+        **dict.fromkeys(
+            ['eva-g', 'eva-giant'],
+            {
+                # The implementation in EVA
+                # <https://arxiv.org/abs/2211.07636>
+                'embed_dims': 1408,
+                'num_layers': 40,
+                'num_heads': 16,
+                'feedforward_channels': 6144
+            }),
+        **dict.fromkeys(
+            ['deit-t', 'deit-tiny'], {
+                'embed_dims': 192,
+                'num_layers': 12,
+                'num_heads': 3,
+                'feedforward_channels': 192 * 4
+            }),
+        **dict.fromkeys(
+            ['deit-s', 'deit-small'], {
+                'embed_dims': 384,
+                'num_layers': 12,
+                'num_heads': 6,
+                'feedforward_channels': 384 * 4
+            }),
+        **dict.fromkeys(
+            ['deit-b', 'deit-base'], {
+                'embed_dims': 768,
+                'num_layers': 12,
+                'num_heads': 12,
+                'feedforward_channels': 768 * 4
+            }),
+    }
+    num_extra_tokens = 1  # class token
+    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
+
     def __init__(self,
                  arch='base',
@@ -300,12 +358,12 @@ class BEiTViT(VisionTransformer):
                  use_abs_pos_emb=False,
                  use_rel_pos_bias=True,
                  use_shared_rel_pos_bias=False,
+                 layer_scale_init_value=0.1,
                  interpolate_mode='bicubic',
-                 layer_scale_init_value=0.1,
                  patch_cfg=dict(),
                  layer_cfgs=dict(),
                  init_cfg=None):
-        super(VisionTransformer, self).__init__(init_cfg)
+        super(BEiTViT, self).__init__(init_cfg)
 
         if isinstance(arch, str):
             arch = arch.lower()
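The lowercased arch string is then matched against the arch_zoo aliases added above. A standalone illustration of that alias pattern (plain Python, independent of mmpretrain), trimmed to two entries:

```python
# dict.fromkeys maps several alias strings to one shared settings dict.
arch_zoo = {
    **dict.fromkeys(['b', 'base'], {
        'embed_dims': 768,
        'num_layers': 12,
        'num_heads': 12,
        'feedforward_channels': 3072,
    }),
    **dict.fromkeys(['l', 'large'], {
        'embed_dims': 1024,
        'num_layers': 24,
        'num_heads': 16,
        'feedforward_channels': 4096,
    }),
}

arch = 'Base'.lower()          # user input is lowercased, as in __init__
settings = arch_zoo[arch]
print(settings['embed_dims'])  # 768

# The aliases share one dict object, so 'b' and 'base' resolve to the
# very same settings mapping.
assert arch_zoo['b'] is arch_zoo['base']
```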
@@ -345,6 +403,7 @@ class BEiTViT(VisionTransformer):
         self.out_type = out_type
 
         # Set cls token
+        self.with_cls_token = with_cls_token
         if with_cls_token:
             self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
             self.num_extra_tokens = 1
@@ -426,6 +485,87 @@ class BEiTViT(VisionTransformer):
         if self.frozen_stages > 0:
             self._freeze_stages()
 
+    @property
+    def norm1(self):
+        return self.ln1
+
+    @property
+    def norm2(self):
+        return self.ln2
+
+    def init_weights(self):
+        super(BEiTViT, self).init_weights()
+
+        if not (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            if self.pos_embed is not None:
+                trunc_normal_(self.pos_embed, std=0.02)
+
+    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
+        name = prefix + 'pos_embed'
+        if name not in state_dict.keys():
+            return
+
+        ckpt_pos_embed_shape = state_dict[name].shape
+        if (not self.with_cls_token
+                and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1):
+            # Remove cls token from state dict if it's not used.
+            state_dict[name] = state_dict[name][:, 1:]
+            ckpt_pos_embed_shape = state_dict[name].shape
+
+        if self.pos_embed.shape != ckpt_pos_embed_shape:
+            from mmengine.logging import MMLogger
+            logger = MMLogger.get_current_instance()
+            logger.info(
+                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
+                f'to {self.pos_embed.shape}.')
+
+            ckpt_pos_embed_shape = to_2tuple(
+                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
+            pos_embed_shape = self.patch_embed.init_out_size
+
+            state_dict[name] = resize_pos_embed(state_dict[name],
+                                                ckpt_pos_embed_shape,
+                                                pos_embed_shape,
+                                                self.interpolate_mode,
+                                                self.num_extra_tokens)
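For intuition, here is a small self-contained sketch of the idea the hook above relies on (a toy stand-in, not mmpretrain's resize_pos_embed helper): keep the extra tokens, reshape the remaining tokens back to their source grid, and interpolate to the target grid.

```python
import torch
import torch.nn.functional as F

def resize_grid_pos_embed(pos_embed, src_hw, dst_hw, num_extra_tokens=1,
                          mode='bicubic'):
    """Toy version of position-embedding resizing: keep the extra (cls)
    tokens and interpolate the patch-grid part to the new resolution."""
    extra = pos_embed[:, :num_extra_tokens]            # e.g. the cls token slot
    grid = pos_embed[:, num_extra_tokens:]             # (1, H*W, C)
    B, _, C = grid.shape
    grid = grid.reshape(B, *src_hw, C).permute(0, 3, 1, 2)  # (1, C, H, W)
    grid = F.interpolate(grid, size=dst_hw, mode=mode, align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(B, -1, C)
    return torch.cat([extra, grid], dim=1)

# Checkpoint trained on a 14x14 patch grid, model expects a 16x16 grid.
ckpt_pos_embed = torch.randn(1, 1 + 14 * 14, 768)
new_pos_embed = resize_grid_pos_embed(ckpt_pos_embed, (14, 14), (16, 16))
print(new_pos_embed.shape)  # torch.Size([1, 257, 768])
```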
+    @staticmethod
+    def resize_pos_embed(*args, **kwargs):
+        """Interface for backward-compatibility."""
+        return resize_pos_embed(*args, **kwargs)
+
+    def _freeze_stages(self):
+        # freeze position embedding
+        if self.pos_embed is not None:
+            self.pos_embed.requires_grad = False
+        # set dropout to eval model
+        self.drop_after_pos.eval()
+        # freeze patch embedding
+        self.patch_embed.eval()
+        for param in self.patch_embed.parameters():
+            param.requires_grad = False
+        # freeze cls_token
+        if self.with_cls_token:
+            self.cls_token.requires_grad = False
+        # freeze layers
+        for i in range(1, self.frozen_stages + 1):
+            m = self.layers[i - 1]
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+        # freeze the last layer norm
+        if self.frozen_stages == len(self.layers):
+            if self.final_norm:
+                self.ln1.eval()
+                for param in self.ln1.parameters():
+                    param.requires_grad = False
+
+            if self.out_type == 'avg_featmap':
+                self.ln2.eval()
+                for param in self.ln2.parameters():
+                    param.requires_grad = False
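A short sketch of how this kind of stage freezing behaves, using a plain PyTorch stand-in (a hypothetical TinyBackbone, not mmpretrain code): freezing disables gradients and puts the affected submodules in eval mode, while later stages stay trainable.

```python
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Hypothetical stand-in that mimics the freezing pattern above."""

    def __init__(self, num_layers=4, dim=8, frozen_stages=2):
        super().__init__()
        self.frozen_stages = frozen_stages
        self.patch_embed = nn.Linear(dim, dim)
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def _freeze_stages(self):
        # freeze the stem
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # freeze the first `frozen_stages` layers
        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

model = TinyBackbone()
model._freeze_stages()
# Only the first `frozen_stages` layers (and the stem) are frozen.
print([all(not p.requires_grad for p in layer.parameters()) for layer in model.layers])
# [True, True, False, False]
```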
 
     def forward(self, x):
         B = x.shape[0]
         x, patch_resolution = self.patch_embed(x)
@@ -520,3 +660,38 @@ class BEiTViT(VisionTransformer):
                 index_buffer = ckpt_key.replace('bias_table', 'index')
                 if index_buffer in state_dict:
                     del state_dict[index_buffer]
+
+    def get_layer_depth(self, param_name: str, prefix: str = ''):
+        """Get the layer-wise depth of a parameter.
+
+        Args:
+            param_name (str): The name of the parameter.
+            prefix (str): The prefix for the parameter.
+                Defaults to an empty string.
+
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the num of layers.
+
+        Note:
+            The first depth is the stem module (``layer_depth=0``), and the
+            last depth is the subsequent module (``layer_depth=num_layers-1``)
+        """
+        num_layers = self.num_layers + 2
+
+        if not param_name.startswith(prefix):
+            # For subsequent module like head
+            return num_layers - 1, num_layers
+
+        param_name = param_name[len(prefix):]
+
+        if param_name in ('cls_token', 'pos_embed'):
+            layer_depth = 0
+        elif param_name.startswith('patch_embed'):
+            layer_depth = 0
+        elif param_name.startswith('layers'):
+            layer_id = int(param_name.split('.')[1])
+            layer_depth = layer_id + 1
+        else:
+            layer_depth = num_layers - 1
+
+        return layer_depth, num_layers
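This depth index is what layer-wise learning-rate decay schemes key off. A standalone sketch of the mapping for a 12-layer model (plain Python mirroring the method above; the parameter names are illustrative, not an exhaustive list of real BEiT parameter names):

```python
def layer_depth(param_name: str, num_transformer_layers: int = 12):
    """Mirror of the mapping above: stem -> 0, block i -> i + 1, rest -> last."""
    num_layers = num_transformer_layers + 2
    if param_name in ('cls_token', 'pos_embed') or param_name.startswith('patch_embed'):
        depth = 0
    elif param_name.startswith('layers'):
        depth = int(param_name.split('.')[1]) + 1
    else:
        depth = num_layers - 1
    return depth, num_layers

for name in ('patch_embed.projection.weight', 'layers.0.attn.qkv.weight',
             'layers.11.ffn.layers.0.0.weight', 'ln1.weight'):
    print(name, layer_depth(name))
# patch_embed.projection.weight (0, 14)
# layers.0.attn.qkv.weight (1, 14)
# layers.11.ffn.layers.0.0.weight (12, 14)
# ln1.weight (13, 14)
```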
@@ -444,10 +444,16 @@ class VisionTransformer(BaseBackbone):
             for param in m.parameters():
                 param.requires_grad = False
         # freeze the last layer norm
-        if self.frozen_stages == len(self.layers) and self.final_norm:
-            self.ln1.eval()
-            for param in self.ln1.parameters():
-                param.requires_grad = False
+        if self.frozen_stages == len(self.layers):
+            if self.final_norm:
+                self.ln1.eval()
+                for param in self.ln1.parameters():
+                    param.requires_grad = False
+
+            if self.out_type == 'avg_featmap':
+                self.ln2.eval()
+                for param in self.ln2.parameters():
+                    param.requires_grad = False
 
     def forward(self, x):
         B = x.shape[0]
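Per the commit message ("Fix freezing vit ln2"), the nested condition makes sure that when every stage is frozen and out_type='avg_featmap', the ln2 norm applied to the averaged feature map is frozen along with ln1, not just the final norm. A quick way to check this behaviour, assuming mmpretrain is installed and that freezing is applied at construction time as in the BEiT hunk above:

```python
from mmpretrain.models import VisionTransformer

# 'b' has 12 layers; freeze all of them and request the avg_featmap output.
vit = VisionTransformer(arch='b', out_type='avg_featmap', frozen_stages=12)

# Expected after this fix: ln2's parameters no longer require gradients.
print(all(not p.requires_grad for p in vit.ln2.parameters()))
```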