From 5c43d3ef42602a8e98ebcbb95944e910afc88f1a Mon Sep 17 00:00:00 2001
From: fanqiNO1 <75657629+fanqiNO1@users.noreply.github.com>
Date: Tue, 11 Jul 2023 15:49:41 +0800
Subject: [PATCH] [Refactor] BEiT refactor (#1705)

* [Refactor] BEiT refactor

* [Fix] Fix arch zoo

* [Fix] Fix arch zoo

* [Fix] Fix freeze stages

* [Fix] Fix freeze ln2

* [Fix] Fix freezing vit ln2
---
 mmpretrain/models/backbones/beit.py           | 183 +++++++++++++++++-
 .../models/backbones/vision_transformer.py    |  14 +-
 2 files changed, 189 insertions(+), 8 deletions(-)

diff --git a/mmpretrain/models/backbones/beit.py b/mmpretrain/models/backbones/beit.py
index 8f64ae20..3c7d9085 100644
--- a/mmpretrain/models/backbones/beit.py
+++ b/mmpretrain/models/backbones/beit.py
@@ -7,11 +7,13 @@ import torch.nn as nn
 from mmcv.cnn.bricks.drop import build_dropout
 from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
 from mmengine.model import BaseModule, ModuleList
+from mmengine.model.weight_init import trunc_normal_

 from mmpretrain.registry import MODELS
 from ..utils import (BEiTAttention, build_norm_layer, resize_pos_embed,
                      resize_relative_position_bias_table, to_2tuple)
-from .vision_transformer import TransformerEncoderLayer, VisionTransformer
+from .base_backbone import BaseBackbone
+from .vision_transformer import TransformerEncoderLayer


 class RelativePositionBias(BaseModule):
@@ -212,7 +214,7 @@ class BEiTTransformerEncoderLayer(TransformerEncoderLayer):


 @MODELS.register_module()
-class BEiTViT(VisionTransformer):
+class BEiTViT(BaseBackbone):
     """Backbone for BEiT.

     A PyTorch implement of : `BEiT: BERT Pre-Training of Image Transformers
@@ -282,6 +284,62 @@ class BEiTViT(VisionTransformer):
         init_cfg (dict, optional): Initialization config dict.
             Defaults to None.
     """
+    arch_zoo = {
+        **dict.fromkeys(
+            ['s', 'small'], {
+                'embed_dims': 768,
+                'num_layers': 8,
+                'num_heads': 8,
+                'feedforward_channels': 768 * 3,
+            }),
+        **dict.fromkeys(
+            ['b', 'base'], {
+                'embed_dims': 768,
+                'num_layers': 12,
+                'num_heads': 12,
+                'feedforward_channels': 3072
+            }),
+        **dict.fromkeys(
+            ['l', 'large'], {
+                'embed_dims': 1024,
+                'num_layers': 24,
+                'num_heads': 16,
+                'feedforward_channels': 4096
+            }),
+        **dict.fromkeys(
+            ['eva-g', 'eva-giant'],
+            {
+                # The implementation in EVA
+                #
+                'embed_dims': 1408,
+                'num_layers': 40,
+                'num_heads': 16,
+                'feedforward_channels': 6144
+            }),
+        **dict.fromkeys(
+            ['deit-t', 'deit-tiny'], {
+                'embed_dims': 192,
+                'num_layers': 12,
+                'num_heads': 3,
+                'feedforward_channels': 192 * 4
+            }),
+        **dict.fromkeys(
+            ['deit-s', 'deit-small'], {
+                'embed_dims': 384,
+                'num_layers': 12,
+                'num_heads': 6,
+                'feedforward_channels': 384 * 4
+            }),
+        **dict.fromkeys(
+            ['deit-b', 'deit-base'], {
+                'embed_dims': 768,
+                'num_layers': 12,
+                'num_heads': 12,
+                'feedforward_channels': 768 * 4
+            }),
+    }
+    num_extra_tokens = 1  # class token
+    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

     def __init__(self,
                  arch='base',
@@ -300,12 +358,12 @@ class BEiTViT(VisionTransformer):
                  use_abs_pos_emb=False,
                  use_rel_pos_bias=True,
                  use_shared_rel_pos_bias=False,
-                 layer_scale_init_value=0.1,
                  interpolate_mode='bicubic',
+                 layer_scale_init_value=0.1,
                  patch_cfg=dict(),
                  layer_cfgs=dict(),
                  init_cfg=None):
-        super(VisionTransformer, self).__init__(init_cfg)
+        super(BEiTViT, self).__init__(init_cfg)

         if isinstance(arch, str):
             arch = arch.lower()
@@ -345,6 +403,7 @@ class BEiTViT(VisionTransformer):
         self.out_type = out_type

         # Set cls token
+        self.with_cls_token = with_cls_token
         if with_cls_token:
             self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
             self.num_extra_tokens = 1
@@ -426,6 +485,87 @@ class BEiTViT(VisionTransformer):
         if self.frozen_stages > 0:
             self._freeze_stages()

+    @property
+    def norm1(self):
+        return self.ln1
+
+    @property
+    def norm2(self):
+        return self.ln2
+
+    def init_weights(self):
+        super(BEiTViT, self).init_weights()
+
+        if not (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            if self.pos_embed is not None:
+                trunc_normal_(self.pos_embed, std=0.02)
+
+    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
+        name = prefix + 'pos_embed'
+        if name not in state_dict.keys():
+            return
+
+        ckpt_pos_embed_shape = state_dict[name].shape
+        if (not self.with_cls_token
+                and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1):
+            # Remove cls token from state dict if it's not used.
+            state_dict[name] = state_dict[name][:, 1:]
+            ckpt_pos_embed_shape = state_dict[name].shape
+
+        if self.pos_embed.shape != ckpt_pos_embed_shape:
+            from mmengine.logging import MMLogger
+            logger = MMLogger.get_current_instance()
+            logger.info(
+                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
+                f'to {self.pos_embed.shape}.')
+
+            ckpt_pos_embed_shape = to_2tuple(
+                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
+            pos_embed_shape = self.patch_embed.init_out_size
+
+            state_dict[name] = resize_pos_embed(state_dict[name],
+                                                ckpt_pos_embed_shape,
+                                                pos_embed_shape,
+                                                self.interpolate_mode,
+                                                self.num_extra_tokens)
+
+    @staticmethod
+    def resize_pos_embed(*args, **kwargs):
+        """Interface for backward-compatibility."""
+        return resize_pos_embed(*args, **kwargs)
+
+    def _freeze_stages(self):
+        # freeze position embedding
+        if self.pos_embed is not None:
+            self.pos_embed.requires_grad = False
+        # set dropout to eval mode
+        self.drop_after_pos.eval()
+        # freeze patch embedding
+        self.patch_embed.eval()
+        for param in self.patch_embed.parameters():
+            param.requires_grad = False
+        # freeze cls_token
+        if self.with_cls_token:
+            self.cls_token.requires_grad = False
+        # freeze layers
+        for i in range(1, self.frozen_stages + 1):
+            m = self.layers[i - 1]
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+        # freeze the last layer norm
+        if self.frozen_stages == len(self.layers):
+            if self.final_norm:
+                self.ln1.eval()
+                for param in self.ln1.parameters():
+                    param.requires_grad = False
+
+            if self.out_type == 'avg_featmap':
+                self.ln2.eval()
+                for param in self.ln2.parameters():
+                    param.requires_grad = False
+
     def forward(self, x):
         B = x.shape[0]
         x, patch_resolution = self.patch_embed(x)
@@ -520,3 +660,38 @@ class BEiTViT(VisionTransformer):
                 index_buffer = ckpt_key.replace('bias_table', 'index')
                 if index_buffer in state_dict:
                     del state_dict[index_buffer]
+
+    def get_layer_depth(self, param_name: str, prefix: str = ''):
+        """Get the layer-wise depth of a parameter.
+
+        Args:
+            param_name (str): The name of the parameter.
+            prefix (str): The prefix for the parameter.
+                Defaults to an empty string.
+
+        Returns:
+            Tuple[int, int]: The layer-wise depth and the number of layers.
+
+        Note:
+            The first depth is the stem module (``layer_depth=0``), and the
+            last depth is the subsequent module (``layer_depth=num_layers-1``).
+        """
+        num_layers = self.num_layers + 2
+
+        if not param_name.startswith(prefix):
+            # For subsequent modules like the head
+            return num_layers - 1, num_layers
+
+        param_name = param_name[len(prefix):]
+
+        if param_name in ('cls_token', 'pos_embed'):
+            layer_depth = 0
+        elif param_name.startswith('patch_embed'):
+            layer_depth = 0
+        elif param_name.startswith('layers'):
+            layer_id = int(param_name.split('.')[1])
+            layer_depth = layer_id + 1
+        else:
+            layer_depth = num_layers - 1
+
+        return layer_depth, num_layers
diff --git a/mmpretrain/models/backbones/vision_transformer.py b/mmpretrain/models/backbones/vision_transformer.py
index 21572f36..d77ac863 100644
--- a/mmpretrain/models/backbones/vision_transformer.py
+++ b/mmpretrain/models/backbones/vision_transformer.py
@@ -444,10 +444,16 @@ class VisionTransformer(BaseBackbone):
             for param in m.parameters():
                 param.requires_grad = False
         # freeze the last layer norm
-        if self.frozen_stages == len(self.layers) and self.final_norm:
-            self.ln1.eval()
-            for param in self.ln1.parameters():
-                param.requires_grad = False
+        if self.frozen_stages == len(self.layers):
+            if self.final_norm:
+                self.ln1.eval()
+                for param in self.ln1.parameters():
+                    param.requires_grad = False
+
+            if self.out_type == 'avg_featmap':
+                self.ln2.eval()
+                for param in self.ln2.parameters():
+                    param.requires_grad = False

     def forward(self, x):
         B = x.shape[0]
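
Usage sketch (not part of the commit): selecting one of the presets registered in the new arch_zoo. The 'deit-tiny' key is taken directly from the table added above; passing a dict instead of a preset name is assumed to work as in the other mmpretrain backbones (the diff only shows the isinstance(arch, str) branch), so the dict form below is illustrative only.

from mmpretrain.models.backbones.beit import BEiTViT

# Preset key taken from the arch_zoo added in this patch.
model = BEiTViT(arch='deit-tiny')

# Assumed (not shown in the diff): a dict with the four essential keys
# can replace a preset name, as in other mmpretrain backbones.
model = BEiTViT(
    arch={
        'embed_dims': 256,
        'num_layers': 4,
        'num_heads': 4,
        'feedforward_channels': 1024,
    })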
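The new get_layer_depth method reports a parameter's depth so that optimizer constructors can scale learning rates layer by layer. A minimal sketch of that use, assuming a patched mmpretrain is importable; the base learning rate, the 0.65 decay factor, and the AdamW grouping are illustrative choices, not part of the patch.

import torch

from mmpretrain.models.backbones.beit import BEiTViT

model = BEiTViT(arch='base')
# num_layers == self.num_layers + 2 (stem + blocks + subsequent module).
_, num_layers = model.get_layer_depth('cls_token')

base_lr, decay = 1e-3, 0.65  # illustrative values
param_groups = []
for name, param in model.named_parameters():
    depth, _ = model.get_layer_depth(name)
    # Depth 0 is the stem, so the shallowest parameters get the smallest rate.
    scale = decay ** (num_layers - 1 - depth)
    param_groups.append({'params': [param], 'lr': base_lr * scale})

optimizer = torch.optim.AdamW(param_groups, lr=base_lr, weight_decay=0.05)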
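The last three fix-up commits change what _freeze_stages freezes: with every stage frozen, ln1 is frozen only when final_norm is set, and ln2 is now frozen as well when out_type='avg_featmap' (previously it stayed trainable). A quick sanity check, assuming 'base' has 12 layers as in the arch_zoo above and that ln2 exists whenever out_type='avg_featmap' (as the freezing code implies):

from mmpretrain.models.backbones.beit import BEiTViT

# frozen_stages equals the layer count of the 'base' arch.
model = BEiTViT(arch='base', out_type='avg_featmap', frozen_stages=12)

# Every transformer layer is frozen ...
assert all(not p.requires_grad for layer in model.layers
           for p in layer.parameters())
# ... and, after this patch, so is the post-pooling norm ln2.
assert all(not p.requires_grad for p in model.ln2.parameters())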