fix_init on vit & relpos vit

parent 935950cc11
commit ac1b08deb6
timm/models/vision_transformer.py

@@ -421,6 +421,7 @@ class VisionTransformer(nn.Module):
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
             weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
+            fix_init: bool = False,
             embed_layer: Callable = PatchEmbed,
             norm_layer: Optional[LayerType] = None,
             act_layer: Optional[LayerType] = None,
@@ -449,6 +450,7 @@ class VisionTransformer(nn.Module):
             attn_drop_rate: Attention dropout rate.
             drop_path_rate: Stochastic depth rate.
             weight_init: Weight initialization scheme.
+            fix_init: Apply weight initialization fix (scaling w/ layer index).
             embed_layer: Patch embedding layer.
             norm_layer: Normalization layer.
             act_layer: MLP activation layer.
@@ -536,8 +538,18 @@ class VisionTransformer(nn.Module):
 
         if weight_init != 'skip':
             self.init_weights(weight_init)
+        if fix_init:
+            self.fix_init_weight()
 
-    def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None:
+    def fix_init_weight(self):
+        def rescale(param, _layer_id):
+            param.div_(math.sqrt(2.0 * _layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def init_weights(self, mode: str = '') -> None:
         assert mode in ('jax', 'jax_nlhb', 'moco', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
         trunc_normal_(self.pos_embed, std=.02)
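With this change, constructing a VisionTransformer with fix_init=True first runs the normal weight_init pass, then divides each block's attn.proj and mlp.fc2 weights by sqrt(2 * layer_id), the depth-scaled rescale used in BEiT/EVA-style codebases. A minimal usage sketch, assuming a timm build that includes this commit (the model name is just an example):

    import timm

    # fix_init is forwarded through create_model kwargs to VisionTransformer;
    # it rescales the residual-branch output weights after the normal init pass.
    model = timm.create_model('vit_base_patch16_224', pretrained=False, fix_init=True)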
@@ -737,7 +749,7 @@ def init_weights_vit_moco(module: nn.Module, name: str = '') -> None:
         module.init_weights()
 
 
-def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> None:
+def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
     if 'jax' in mode:
         return partial(init_weights_vit_jax, head_bias=head_bias)
     elif 'moco' in mode:
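The corrected annotation reflects what get_init_weights_vit actually returns: an init callable (a functools.partial over one of the init_weights_vit_* functions) intended to be walked over the module tree with named_apply. A small sketch of that pattern, assuming the timm internals shown in this diff:

    from timm.models import VisionTransformer
    from timm.models._manipulate import named_apply
    from timm.models.vision_transformer import get_init_weights_vit

    # Build with weight_init='skip', then apply a chosen init scheme manually.
    model = VisionTransformer(embed_dim=192, depth=12, num_heads=3, weight_init='skip')
    init_fn = get_init_weights_vit('jax', head_bias=0.0)
    named_apply(init_fn, model)  # calls init_fn(module=m, name=n) on each submodule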
timm/models/vision_transformer_relpos.py

@@ -7,7 +7,11 @@ Hacked together by / Copyright 2022, Ross Wightman
 import logging
 import math
 from functools import partial
-from typing import Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, List
+try:
+    from typing import Literal
+except ImportError:
+    from typing_extensions import Literal
 
 import torch
 import torch.nn as nn
@@ -15,9 +19,11 @@ from torch.jit import Final
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn
+from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
 from ._builder import build_model_with_cfg
+from ._manipulate import named_apply
 from ._registry import generate_default_cfgs, register_model
+from .vision_transformer import get_init_weights_vit
 
 __all__ = ['VisionTransformerRelPos']  # model_registry will add each entrypoint fn to this
 
@@ -215,59 +221,61 @@ class VisionTransformerRelPos(nn.Module):
 
     def __init__(
             self,
-            img_size=224,
-            patch_size=16,
-            in_chans=3,
-            num_classes=1000,
-            global_pool='avg',
-            embed_dim=768,
-            depth=12,
-            num_heads=12,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            qk_norm=False,
-            init_values=1e-6,
-            class_token=False,
-            fc_norm=False,
-            rel_pos_type='mlp',
-            rel_pos_dim=None,
-            shared_rel_pos=False,
-            drop_rate=0.,
-            proj_drop_rate=0.,
-            attn_drop_rate=0.,
-            drop_path_rate=0.,
-            weight_init='skip',
-            embed_layer=PatchEmbed,
-            norm_layer=None,
-            act_layer=None,
-            block_fn=RelPosBlock
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 16,
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            global_pool: Literal['', 'avg', 'token', 'map'] = 'avg',
+            embed_dim: int = 768,
+            depth: int = 12,
+            num_heads: int = 12,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            qk_norm: bool = False,
+            init_values: Optional[float] = 1e-6,
+            class_token: bool = False,
+            fc_norm: bool = False,
+            rel_pos_type: str = 'mlp',
+            rel_pos_dim: Optional[int] = None,
+            shared_rel_pos: bool = False,
+            drop_rate: float = 0.,
+            proj_drop_rate: float = 0.,
+            attn_drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            weight_init: Literal['skip', 'jax', 'moco', ''] = 'skip',
+            fix_init: bool = False,
+            embed_layer: Type[nn.Module] = PatchEmbed,
+            norm_layer: Optional[LayerType] = None,
+            act_layer: Optional[LayerType] = None,
+            block_fn: Type[nn.Module] = RelPosBlock
     ):
         """
         Args:
-            img_size (int, tuple): input image size
-            patch_size (int, tuple): patch size
-            in_chans (int): number of input channels
-            num_classes (int): number of classes for classification head
-            global_pool (str): type of global pooling for final sequence (default: 'avg')
-            embed_dim (int): embedding dimension
-            depth (int): depth of transformer
-            num_heads (int): number of attention heads
-            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
-            qkv_bias (bool): enable bias for qkv if True
-            qk_norm (bool): Enable normalization of query and key in attention
-            init_values: (float): layer-scale init values
-            class_token (bool): use class token (default: False)
-            fc_norm (bool): use pre classifier norm instead of pre-pool
-            rel_pos_type (str): type of relative position
-            shared_rel_pos (bool): share relative pos across all blocks
-            drop_rate (float): dropout rate
-            proj_drop_rate (float): projection dropout rate
-            attn_drop_rate (float): attention dropout rate
-            drop_path_rate (float): stochastic depth rate
-            weight_init (str): weight init scheme
-            embed_layer (nn.Module): patch embedding layer
-            norm_layer: (nn.Module): normalization layer
-            act_layer: (nn.Module): MLP activation layer
+            img_size: input image size
+            patch_size: patch size
+            in_chans: number of input channels
+            num_classes: number of classes for classification head
+            global_pool: type of global pooling for final sequence (default: 'avg')
+            embed_dim: embedding dimension
+            depth: depth of transformer
+            num_heads: number of attention heads
+            mlp_ratio: ratio of mlp hidden dim to embedding dim
+            qkv_bias: enable bias for qkv if True
+            qk_norm: Enable normalization of query and key in attention
+            init_values: layer-scale init values
+            class_token: use class token (default: False)
+            fc_norm: use pre classifier norm instead of pre-pool
+            rel_pos_type: type of relative position
+            shared_rel_pos: share relative pos across all blocks
+            drop_rate: dropout rate
+            proj_drop_rate: projection dropout rate
+            attn_drop_rate: attention dropout rate
+            drop_path_rate: stochastic depth rate
+            weight_init: weight init scheme
+            fix_init: apply weight initialization fix (scaling w/ layer index)
+            embed_layer: patch embedding layer
+            norm_layer: normalization layer
+            act_layer: MLP activation layer
         """
         super().__init__()
         assert global_pool in ('', 'avg', 'token')
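Aside from the new fix_init flag, the signature change is annotation-only; the defaults are unchanged. A construction sketch using the typed arguments (the configuration below is hypothetical, not from the commit):

    import torch
    from timm.models.vision_transformer_relpos import VisionTransformerRelPos

    model = VisionTransformerRelPos(
        img_size=224,
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        rel_pos_type='mlp',  # relative position bias via RelPosMlp
        weight_init='',      # run init_weights() instead of skipping it
        fix_init=True,       # then apply the layer-index rescale added below
    )
    x = torch.randn(1, 3, 224, 224)
    logits = model(x)  # shape (1, 1000) with the default num_classes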
@@ -332,13 +340,22 @@ class VisionTransformerRelPos(nn.Module):
 
         if weight_init != 'skip':
             self.init_weights(weight_init)
+        if fix_init:
+            self.fix_init_weight()
 
     def init_weights(self, mode=''):
         assert mode in ('jax', 'moco', '')
         if self.cls_token is not None:
             nn.init.normal_(self.cls_token, std=1e-6)
-        # FIXME weight init scheme using PyTorch defaults curently
-        #named_apply(get_init_weights_vit(mode, head_bias), self)
+        named_apply(get_init_weights_vit(mode), self)
+
+    def fix_init_weight(self):
+        def rescale(param, _layer_id):
+            param.div_(math.sqrt(2.0 * _layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
 
     @torch.jit.ignore
     def no_weight_decay(self):
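For intuition on the sqrt(2.0 * _layer_id) factor: in a pre-norm ViT each block adds two residual branches (attention and MLP), so roughly 2 * depth near-unit-variance contributions accumulate by the final layer; dividing block l's output projections by sqrt(2l) damps the later contributions. A standalone toy demo of the effect (illustrative only, not part of the commit):

    import math
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    depth, dim = 24, 256
    h0 = torch.randn(1024, dim)

    for fix_init in (False, True):
        h = h0.clone()
        for layer_id in range(1, depth + 1):
            for _ in range(2):  # stand-ins for the attn and mlp residual branches
                w = torch.randn(dim, dim) / math.sqrt(dim)  # roughly unit-gain linear
                if fix_init:
                    w = w / math.sqrt(2.0 * layer_id)       # the commit's rescale
                h = h + F.layer_norm(h, (dim,)) @ w.t()     # pre-norm residual add
        print(f'fix_init={fix_init}: output std {h.std().item():.2f}')

The variance of the un-rescaled stack grows roughly linearly in 2 * depth, while the rescaled one grows only with the harmonic sum of the layer indices.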