diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index a843f139..fd4ed476 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -64,7 +64,7 @@ class Attention(nn.Module):
             num_heads: int = 8,
             qkv_bias: bool = False,
             qk_norm: bool = False,
-            scale_attn_norm: bool = False,
+            scale_norm: bool = False,
             proj_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
@@ -80,7 +80,7 @@ class Attention(nn.Module):
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.norm = norm_layer(dim) if scale_attn_norm else nn.Identity()
+        self.norm = norm_layer(dim) if scale_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
@@ -151,7 +151,7 @@ class Block(nn.Module):
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
-            scale_attn_norm=scale_attn_norm,
+            scale_norm=scale_attn_norm,
             proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
@@ -205,7 +205,7 @@ class ResPostBlock(nn.Module):
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
-            scale_attn_norm=scale_attn_norm,
+            scale_norm=scale_attn_norm,
             proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
@@ -253,6 +253,8 @@ class ParallelScalingBlock(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
             proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
@@ -264,6 +266,7 @@ class ParallelScalingBlock(nn.Module):
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported'
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim ** -0.5
@@ -348,6 +351,8 @@ class ParallelThingsBlock(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
             proj_bias: bool = True,
             init_values: Optional[float] = None,
             proj_drop: float = 0.,
@@ -369,6 +374,7 @@ class ParallelThingsBlock(nn.Module):
                     num_heads=num_heads,
                     qkv_bias=qkv_bias,
                     qk_norm=qk_norm,
+                    scale_norm=scale_attn_norm,
                     proj_bias=proj_bias,
                     attn_drop=attn_drop,
                     proj_drop=proj_drop,
@@ -383,6 +389,7 @@ class ParallelThingsBlock(nn.Module):
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
+                    norm_layer=norm_layer if scale_mlp_norm else None,
                     bias=proj_bias,
                     drop=proj_drop,
                 )),
diff --git a/timm/models/vitamin.py b/timm/models/vitamin.py
index f8e2a9a1..7d023e0d 100644
--- a/timm/models/vitamin.py
+++ b/timm/models/vitamin.py
@@ -259,11 +259,12 @@ class GeGluMlp(nn.Module):
             in_features,
             hidden_features,
             act_layer = 'gelu',
+            norm_layer = None,
             bias = True,
             drop = 0.0,
     ):
         super().__init__()
-        norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
+        norm_layer = partial(get_norm_layer(norm_layer or 'layernorm'), eps=1e-6)

         self.norm = norm_layer(in_features)
         self.w0 = nn.Linear(in_features, hidden_features, bias=bias)
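
A minimal usage sketch of the renamed argument, assuming this patch is applied on top of timm; the dim, num_heads, norm_layer choice, and tensor shapes below are illustrative and not taken from the diff:

import torch
import torch.nn as nn

from timm.models.vision_transformer import Attention, Block

# Attention now takes `scale_norm` directly (was `scale_attn_norm`).
attn = Attention(dim=768, num_heads=12, scale_norm=True, norm_layer=nn.LayerNorm)

# Block keeps its `scale_attn_norm` argument and forwards it to Attention as `scale_norm`.
blk = Block(dim=768, num_heads=12, scale_attn_norm=True)

x = torch.randn(2, 197, 768)   # (batch, tokens, dim), illustrative shapes
print(attn(x).shape)           # torch.Size([2, 197, 768])
print(blk(x).shape)            # torch.Size([2, 197, 768])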