diff --git a/timm/layers/mlp.py b/timm/layers/mlp.py
index 11d9eeca..09472eed 100644
--- a/timm/layers/mlp.py
+++ b/timm/layers/mlp.py
@@ -132,7 +132,8 @@ class SwiGLU(nn.Module):
 
     def init_weights(self):
         # override init of fc1 w/ gate portion set to weight near zero, bias=1
-        nn.init.ones_(self.fc1_g.bias)
+        if self.fc1_g.bias is not None:
+            nn.init.ones_(self.fc1_g.bias)
         nn.init.normal_(self.fc1_g.weight, std=1e-6)
 
     def forward(self, x):
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 6bc93dd1..7935089d 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -44,7 +44,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
     trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType
+    SwiGLU, get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -65,6 +65,7 @@ class Attention(nn.Module):
             num_heads: int = 8,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             norm_layer: nn.Module = nn.LayerNorm,
@@ -80,7 +81,7 @@
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -130,6 +131,7 @@ class Block(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -145,6 +147,7 @@
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
+            proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
@@ -157,6 +160,7 @@
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
+            bias=proj_bias,
             drop=proj_drop,
         )
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
@@ -176,6 +180,7 @@ class ResPostBlock(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -192,6 +197,7 @@
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
+            proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
@@ -203,6 +209,7 @@
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
+            bias=proj_bias,
             drop=proj_drop,
         )
         self.norm2 = norm_layer(dim)
@@ -236,6 +243,7 @@ class ParallelScalingBlock(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -266,11 +274,11 @@
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.attn_out_proj = nn.Linear(dim, dim)
+        self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias)
 
         self.mlp_drop = nn.Dropout(proj_drop)
         self.mlp_act = act_layer()
-        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
+        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias)
 
         self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -330,6 +338,7 @@ class ParallelThingsBlock(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             init_values: Optional[float] = None,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
@@ -350,6 +359,7 @@
                     num_heads=num_heads,
                     qkv_bias=qkv_bias,
                     qk_norm=qk_norm,
+                    proj_bias=proj_bias,
                     attn_drop=attn_drop,
                     proj_drop=proj_drop,
                     norm_layer=norm_layer,
@@ -363,6 +373,7 @@
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
+                    bias=proj_bias,
                     drop=proj_drop,
                 )),
                 ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
@@ -433,6 +444,7 @@ class VisionTransformer(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             init_values: Optional[float] = None,
             class_token: bool = True,
             pos_embed: str = 'learn',
@@ -452,6 +464,7 @@
             weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
             fix_init: bool = False,
             embed_layer: Callable = PatchEmbed,
+            embed_norm_layer: Optional[LayerType] = None,
             norm_layer: Optional[LayerType] = None,
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
@@ -483,6 +496,7 @@
             weight_init: Weight initialization scheme.
             fix_init: Apply weight initialization fix (scaling w/ layer index).
             embed_layer: Patch embedding layer.
+            embed_norm_layer: Normalization layer to use / override in patch embed module.
             norm_layer: Normalization layer.
             act_layer: MLP activation layer.
             block_fn: Transformer block layer.
@@ -493,6 +507,7 @@ class VisionTransformer(nn.Module):
         assert pos_embed in ('', 'none', 'learn')
         use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        embed_norm_layer = get_norm_layer(embed_norm_layer)
         act_layer = get_act_layer(act_layer) or nn.GELU
 
         self.num_classes = num_classes
@@ -510,6 +525,8 @@
         if dynamic_img_size:
             # flatten deferred until after pos embed
             embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
+        if embed_norm_layer is not None:
+            embed_args['norm_layer'] = embed_norm_layer
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
@@ -539,7 +556,7 @@
             self.patch_drop = nn.Identity()
         self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
 
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device='cpu')]  # stochastic depth decay rule
         self.blocks = nn.Sequential(*[
             block_fn(
                 dim=embed_dim,
@@ -547,6 +564,7 @@
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 qk_norm=qk_norm,
+                proj_bias=proj_bias,
                 init_values=init_values,
                 proj_drop=proj_drop_rate,
                 attn_drop=attn_drop_rate,
@@ -1128,6 +1146,24 @@ def _convert_dinov2(
     return out_dict
 
 
+def _convert_aimv2(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+) -> Dict[str, torch.Tensor]:
+    out_dict = {}
+    for k, v in state_dict.items():
+        k = k.replace('norm_1', 'norm1')
+        k = k.replace('norm_2', 'norm2')
+        k = k.replace('preprocessor.patchifier.', 'patch_embed.')
+        k = k.replace('preprocessor.pos_embed', 'pos_embed')
+        k = k.replace('trunk.', '')
+        k = k.replace('mlp.fc1', 'mlp.fc1_g')
+        k = k.replace('mlp.fc3', 'mlp.fc1_x')
+        k = k.replace('post_trunk_norm.', 'norm.')
+        out_dict[k] = v
+    return out_dict
+
+
 def checkpoint_filter_fn(
         state_dict: Dict[str, torch.Tensor],
         model: VisionTransformer,
@@ -1159,6 +1195,8 @@ def checkpoint_filter_fn(
         # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         out_dict['head.weight'] = state_dict['visual.head.proj.weight']
         out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+    elif 'preprocessor.patchifier.proj.weight' in state_dict:
+        state_dict = _convert_aimv2(state_dict, model)
 
     if prefix:
         # filter on & remove prefix string from keys
@@ -2119,6 +2157,12 @@
         input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
     ),
 
+    'vit_large_patch14_aimv2_224': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 224, 224), crop_pct=1.0,
+        num_classes=0),
+
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 160, 160), crop_pct=0.95),
@@ -3390,6 +3434,21 @@ def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTran
     return model
 
 
+@register_model
+def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Large (ViT-L/14) AIM-v2 image encoder w/ SwiGLU MLP, RMSNorm, and no QKV / projection bias.
+    """
+    rms_norm = partial(RmsNorm, eps=1e-5)
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLU,
+        act_layer='silu', qkv_bias=False, proj_bias=False,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT Test