fix_init on vit & relpos vit

This commit is contained in:
Ross Wightman 2024-02-10 20:15:37 -08:00
parent 935950cc11
commit ac1b08deb6
2 changed files with 85 additions and 56 deletions

View File

@ -421,6 +421,7 @@ class VisionTransformer(nn.Module):
attn_drop_rate: float = 0., attn_drop_rate: float = 0.,
drop_path_rate: float = 0., drop_path_rate: float = 0.,
weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '', weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
fix_init: bool = False,
embed_layer: Callable = PatchEmbed, embed_layer: Callable = PatchEmbed,
norm_layer: Optional[LayerType] = None, norm_layer: Optional[LayerType] = None,
act_layer: Optional[LayerType] = None, act_layer: Optional[LayerType] = None,
@ -449,6 +450,7 @@ class VisionTransformer(nn.Module):
attn_drop_rate: Attention dropout rate. attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate. drop_path_rate: Stochastic depth rate.
weight_init: Weight initialization scheme. weight_init: Weight initialization scheme.
fix_init: Apply weight initialization fix (scaling w/ layer index).
embed_layer: Patch embedding layer. embed_layer: Patch embedding layer.
norm_layer: Normalization layer. norm_layer: Normalization layer.
act_layer: MLP activation layer. act_layer: MLP activation layer.
@ -536,8 +538,18 @@ class VisionTransformer(nn.Module):
if weight_init != 'skip': if weight_init != 'skip':
self.init_weights(weight_init) self.init_weights(weight_init)
if fix_init:
self.fix_init_weight()
def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None: def fix_init_weight(self):
def rescale(param, _layer_id):
param.div_(math.sqrt(2.0 * _layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def init_weights(self, mode: str = '') -> None:
assert mode in ('jax', 'jax_nlhb', 'moco', '') assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.pos_embed, std=.02)
@ -737,7 +749,7 @@ def init_weights_vit_moco(module: nn.Module, name: str = '') -> None:
module.init_weights() module.init_weights()
def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> None: def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
if 'jax' in mode: if 'jax' in mode:
return partial(init_weights_vit_jax, head_bias=head_bias) return partial(init_weights_vit_jax, head_bias=head_bias)
elif 'moco' in mode: elif 'moco' in mode:

View File

@ -7,7 +7,11 @@ Hacked together by / Copyright 2022, Ross Wightman
import logging import logging
import math import math
from functools import partial from functools import partial
from typing import Optional, Tuple from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -15,9 +19,11 @@ from torch.jit import Final
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import named_apply
from ._registry import generate_default_cfgs, register_model from ._registry import generate_default_cfgs, register_model
from .vision_transformer import get_init_weights_vit
__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this __all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this
@ -215,59 +221,61 @@ class VisionTransformerRelPos(nn.Module):
def __init__( def __init__(
self, self,
img_size=224, img_size: Union[int, Tuple[int, int]] = 224,
patch_size=16, patch_size: Union[int, Tuple[int, int]] = 16,
in_chans=3, in_chans: int = 3,
num_classes=1000, num_classes: int = 1000,
global_pool='avg', global_pool: Literal['', 'avg', 'token', 'map'] = 'avg',
embed_dim=768, embed_dim: int = 768,
depth=12, depth: int = 12,
num_heads=12, num_heads: int = 12,
mlp_ratio=4., mlp_ratio: float = 4.,
qkv_bias=True, qkv_bias: bool = True,
qk_norm=False, qk_norm: bool = False,
init_values=1e-6, init_values: Optional[float] = 1e-6,
class_token=False, class_token: bool = False,
fc_norm=False, fc_norm: bool = False,
rel_pos_type='mlp', rel_pos_type: str = 'mlp',
rel_pos_dim=None, rel_pos_dim: Optional[int] = None,
shared_rel_pos=False, shared_rel_pos: bool = False,
drop_rate=0., drop_rate: float = 0.,
proj_drop_rate=0., proj_drop_rate: float = 0.,
attn_drop_rate=0., attn_drop_rate: float = 0.,
drop_path_rate=0., drop_path_rate: float = 0.,
weight_init='skip', weight_init: Literal['skip', 'jax', 'moco', ''] = 'skip',
embed_layer=PatchEmbed, fix_init: bool = False,
norm_layer=None, embed_layer: Type[nn.Module] = PatchEmbed,
act_layer=None, norm_layer: Optional[LayerType] = None,
block_fn=RelPosBlock act_layer: Optional[LayerType] = None,
block_fn: Type[nn.Module] = RelPosBlock
): ):
""" """
Args: Args:
img_size (int, tuple): input image size img_size: input image size
patch_size (int, tuple): patch size patch_size: patch size
in_chans (int): number of input channels in_chans: number of input channels
num_classes (int): number of classes for classification head num_classes: number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'avg') global_pool: type of global pooling for final sequence (default: 'avg')
embed_dim (int): embedding dimension embed_dim: embedding dimension
depth (int): depth of transformer depth: depth of transformer
num_heads (int): number of attention heads num_heads: number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim mlp_ratio: ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True qkv_bias: enable bias for qkv if True
qk_norm (bool): Enable normalization of query and key in attention qk_norm: Enable normalization of query and key in attention
init_values: (float): layer-scale init values init_values: layer-scale init values
class_token (bool): use class token (default: False) class_token: use class token (default: False)
fc_norm (bool): use pre classifier norm instead of pre-pool fc_norm: use pre classifier norm instead of pre-pool
rel_pos_ty pe (str): type of relative position rel_pos_type: type of relative position
shared_rel_pos (bool): share relative pos across all blocks shared_rel_pos: share relative pos across all blocks
drop_rate (float): dropout rate drop_rate: dropout rate
proj_drop_rate (float): projection dropout rate proj_drop_rate: projection dropout rate
attn_drop_rate (float): attention dropout rate attn_drop_rate: attention dropout rate
drop_path_rate (float): stochastic depth rate drop_path_rate: stochastic depth rate
weight_init (str): weight init scheme weight_init: weight init scheme
embed_layer (nn.Module): patch embedding layer fix_init: apply weight initialization fix (scaling w/ layer index)
norm_layer: (nn.Module): normalization layer embed_layer: patch embedding layer
act_layer: (nn.Module): MLP activation layer norm_layer: normalization layer
act_layer: MLP activation layer
""" """
super().__init__() super().__init__()
assert global_pool in ('', 'avg', 'token') assert global_pool in ('', 'avg', 'token')
@ -332,13 +340,22 @@ class VisionTransformerRelPos(nn.Module):
if weight_init != 'skip': if weight_init != 'skip':
self.init_weights(weight_init) self.init_weights(weight_init)
if fix_init:
self.fix_init_weight()
def init_weights(self, mode=''): def init_weights(self, mode=''):
assert mode in ('jax', 'moco', '') assert mode in ('jax', 'moco', '')
if self.cls_token is not None: if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6) nn.init.normal_(self.cls_token, std=1e-6)
# FIXME weight init scheme using PyTorch defaults curently named_apply(get_init_weights_vit(mode), self)
#named_apply(get_init_weights_vit(mode, head_bias), self)
def fix_init_weight(self) -> None:
    """Rescale each block's residual-branch output projections by depth.

    For the block at 1-based position ``i`` in ``self.blocks``, the weights of
    ``attn.proj`` and ``mlp.fc2`` are divided in place by ``sqrt(2 * i)``.
    Only the weights are touched; biases are left as they are.

    The update runs under ``torch.no_grad()`` on the parameters themselves
    rather than through ``.data``, so autograd does not record the operation
    while the tensors' in-place version tracking stays intact.
    """
    def rescale(param, _layer_id):
        param.div_(math.sqrt(2.0 * _layer_id))

    with torch.no_grad():
        # Layer ids start at 1 so the first block is divided by sqrt(2), never by 0.
        for layer_id, layer in enumerate(self.blocks, start=1):
            rescale(layer.attn.proj.weight, layer_id)
            rescale(layer.mlp.fc2.weight, layer_id)
@torch.jit.ignore @torch.jit.ignore
def no_weight_decay(self): def no_weight_decay(self):