fix_init on vit & relpos vit

pull/2092/head
Ross Wightman 2024-02-10 20:15:37 -08:00
parent 935950cc11
commit ac1b08deb6
2 changed files with 85 additions and 56 deletions
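Both files gain the same optional fix_init flag: after the normal weight init, each block's attention output projection and MLP fc2 weights are divided in place by sqrt(2.0 * layer_id), with a 1-indexed layer id. A standalone sketch of that rescale on dummy stand-in modules (illustrative names, not the actual timm classes):

# Standalone sketch of the rescale being added (dummy modules, not the timm
# classes): deeper blocks get proportionally smaller residual-branch weights.
import math

import torch.nn as nn


class DummyBlock(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)  # stands in for block.attn.proj
        self.fc2 = nn.Linear(dim, dim)   # stands in for block.mlp.fc2


blocks = nn.ModuleList([DummyBlock() for _ in range(4)])

for layer_id, block in enumerate(blocks):
    scale = math.sqrt(2.0 * (layer_id + 1))  # 1-indexed, matching layer_id + 1 in the diff below
    block.proj.weight.data.div_(scale)
    block.fc2.weight.data.div_(scale)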

timm/models/vision_transformer.py

@@ -421,6 +421,7 @@ class VisionTransformer(nn.Module):
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
fix_init: bool = False,
embed_layer: Callable = PatchEmbed,
norm_layer: Optional[LayerType] = None,
act_layer: Optional[LayerType] = None,
@@ -449,6 +450,7 @@ class VisionTransformer(nn.Module):
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
weight_init: Weight initialization scheme.
fix_init: Apply weight initialization fix (scaling w/ layer index).
embed_layer: Patch embedding layer.
norm_layer: Normalization layer.
act_layer: MLP activation layer.
@@ -536,8 +538,18 @@ class VisionTransformer(nn.Module):
if weight_init != 'skip':
self.init_weights(weight_init)
if fix_init:
self.fix_init_weight()
def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None:
def fix_init_weight(self):
def rescale(param, _layer_id):
param.div_(math.sqrt(2.0 * _layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def init_weights(self, mode: str = '') -> None:
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02)
@@ -737,7 +749,7 @@ def init_weights_vit_moco(module: nn.Module, name: str = '') -> None:
module.init_weights()
def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> None:
def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
if 'jax' in mode:
return partial(init_weights_vit_jax, head_bias=head_bias)
elif 'moco' in mode:
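
Because timm.create_model() forwards extra kwargs to the model constructor, the new flag should be usable straight from the factory once this change is in. A usage sketch, assuming a timm build that includes this commit; the model name is just an existing registry entry picked for illustration:

# Usage sketch, assuming a timm build that includes this commit; fix_init is
# passed through create_model kwargs to the VisionTransformer constructor.
import timm

model = timm.create_model(
    'vit_base_patch16_224',  # any registered ViT entrypoint works here
    pretrained=False,
    fix_init=True,           # triggers fix_init_weight() after init_weights()
)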

timm/models/vision_transformer_relpos.py

@@ -7,7 +7,11 @@ Hacked together by / Copyright 2022, Ross Wightman
import logging
import math
from functools import partial
from typing import Optional, Tuple
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
import torch.nn as nn
@@ -15,9 +19,11 @@ from torch.jit import Final
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
from ._builder import build_model_with_cfg
from ._manipulate import named_apply
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import get_init_weights_vit
__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this
@@ -215,59 +221,61 @@ class VisionTransformerRelPos(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='avg',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
qk_norm=False,
init_values=1e-6,
class_token=False,
fc_norm=False,
rel_pos_type='mlp',
rel_pos_dim=None,
shared_rel_pos=False,
drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='skip',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
block_fn=RelPosBlock
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: Literal['', 'avg', 'token', 'map'] = 'avg',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
qk_norm: bool = False,
init_values: Optional[float] = 1e-6,
class_token: bool = False,
fc_norm: bool = False,
rel_pos_type: str = 'mlp',
rel_pos_dim: Optional[int] = None,
shared_rel_pos: bool = False,
drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init: Literal['skip', 'jax', 'moco', ''] = 'skip',
fix_init: bool = False,
embed_layer: Type[nn.Module] = PatchEmbed,
norm_layer: Optional[LayerType] = None,
act_layer: Optional[LayerType] = None,
block_fn: Type[nn.Module] = RelPosBlock
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'avg')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_norm (bool): Enable normalization of query and key in attention
init_values: (float): layer-scale init values
class_token (bool): use class token (default: False)
fc_norm (bool): use pre classifier norm instead of pre-pool
rel_pos_type (str): type of relative position
shared_rel_pos (bool): share relative pos across all blocks
drop_rate (float): dropout rate
proj_drop_rate (float): projection dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
img_size: input image size
patch_size: patch size
in_chans: number of input channels
num_classes: number of classes for classification head
global_pool: type of global pooling for final sequence (default: 'avg')
embed_dim: embedding dimension
depth: depth of transformer
num_heads: number of attention heads
mlp_ratio: ratio of mlp hidden dim to embedding dim
qkv_bias: enable bias for qkv if True
qk_norm: Enable normalization of query and key in attention
init_values: layer-scale init values
class_token: use class token (default: False)
fc_norm: use pre classifier norm instead of pre-pool
rel_pos_type: type of relative position
shared_rel_pos: share relative pos across all blocks
drop_rate: dropout rate
proj_drop_rate: projection dropout rate
attn_drop_rate: attention dropout rate
drop_path_rate: stochastic depth rate
weight_init: weight init scheme
fix_init: apply weight initialization fix (scaling w/ layer index)
embed_layer: patch embedding layer
norm_layer: normalization layer
act_layer: MLP activation layer
"""
super().__init__()
assert global_pool in ('', 'avg', 'token')
@@ -332,13 +340,22 @@ class VisionTransformerRelPos(nn.Module):
if weight_init != 'skip':
self.init_weights(weight_init)
if fix_init:
self.fix_init_weight()
def init_weights(self, mode=''):
assert mode in ('jax', 'moco', '')
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
# FIXME weight init scheme using PyTorch defaults currently
#named_apply(get_init_weights_vit(mode, head_bias), self)
named_apply(get_init_weights_vit(mode), self)
def fix_init_weight(self):
def rescale(param, _layer_id):
param.div_(math.sqrt(2.0 * _layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
@torch.jit.ignore
def no_weight_decay(self):
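
The relpos model gets the same rescale, applied to blocks[i].attn.proj and blocks[i].mlp.fc2. A quick verification sketch, assuming a local timm checkout containing this change: the deepest block's mlp.fc2 weight should end up scaled by exactly 1 / sqrt(2 * depth).

# Verification sketch (assumes a timm build with this commit applied).
import math

import torch
import timm

m = timm.create_model('vit_relpos_base_patch16_224', pretrained=False)
before = m.blocks[-1].mlp.fc2.weight.detach().clone()
m.fix_init_weight()  # method added by this commit
after = m.blocks[-1].mlp.fc2.weight.detach()
print(torch.allclose(after, before / math.sqrt(2.0 * len(m.blocks))))  # expect: True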