Fix parallel blocks missing scale args and vitamin MLP

Ross Wightman 2025-05-29 11:34:19 -07:00
parent 2ca94a6ce4
commit 3a3d98bc38
2 changed files with 13 additions and 5 deletions

timm/models/vision_transformer.py

@@ -64,7 +64,7 @@ class Attention(nn.Module):
             num_heads: int = 8,
             qkv_bias: bool = False,
             qk_norm: bool = False,
-            scale_attn_norm: bool = False,
+            scale_norm: bool = False,
             proj_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
@@ -80,7 +80,7 @@ class Attention(nn.Module):
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.norm = norm_layer(dim) if scale_attn_norm else nn.Identity()
+        self.norm = norm_layer(dim) if scale_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
@@ -151,7 +151,7 @@ class Block(nn.Module):
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
-            scale_attn_norm=scale_attn_norm,
+            scale_norm=scale_attn_norm,
             proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
@@ -205,7 +205,7 @@ class ResPostBlock(nn.Module):
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
-            scale_attn_norm=scale_attn_norm,
+            scale_norm=scale_attn_norm,
             proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
@@ -253,6 +253,8 @@ class ParallelScalingBlock(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
             proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
@@ -264,6 +266,7 @@ class ParallelScalingBlock(nn.Module):
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported'
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim ** -0.5
@@ -348,6 +351,8 @@ class ParallelThingsBlock(nn.Module):
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
             proj_bias: bool = True,
             init_values: Optional[float] = None,
             proj_drop: float = 0.,
@@ -369,6 +374,7 @@ class ParallelThingsBlock(nn.Module):
                     num_heads=num_heads,
                     qkv_bias=qkv_bias,
                     qk_norm=qk_norm,
+                    scale_norm=scale_attn_norm,
                     proj_bias=proj_bias,
                     attn_drop=attn_drop,
                     proj_drop=proj_drop,
@@ -383,6 +389,7 @@ class ParallelThingsBlock(nn.Module):
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
+                    norm_layer=norm_layer if scale_mlp_norm else None,
                     bias=proj_bias,
                     drop=proj_drop,
                 )),
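
For context, a minimal usage sketch of what the added arguments enable (not part of the commit, and assuming a timm build that includes this change; the dim, head count, and tensor shape below are arbitrary):

    import torch
    from timm.models.vision_transformer import ParallelScalingBlock, ParallelThingsBlock

    x = torch.randn(2, 197, 384)

    # ParallelThingsBlock now forwards the flags: scale_attn_norm -> Attention(scale_norm=...),
    # scale_mlp_norm -> the MLP's norm_layer argument.
    blk = ParallelThingsBlock(dim=384, num_heads=6, scale_attn_norm=True, scale_mlp_norm=True)
    print(blk(x).shape)  # torch.Size([2, 197, 384])

    # ParallelScalingBlock accepts the new arguments only so factories can pass them uniformly;
    # enabling either flag trips the added assert.
    try:
        ParallelScalingBlock(dim=384, num_heads=6, scale_attn_norm=True)
    except AssertionError as e:
        print(e)  # Scale norms not supported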

timm/models/vitamin.py

@@ -259,11 +259,12 @@ class GeGluMlp(nn.Module):
             in_features,
             hidden_features,
             act_layer = 'gelu',
+            norm_layer = None,
             bias = True,
             drop = 0.0,
     ):
         super().__init__()
-        norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
+        norm_layer = partial(get_norm_layer(norm_layer or 'layernorm'), eps=1e-6)
         self.norm = norm_layer(in_features)
         self.w0 = nn.Linear(in_features, hidden_features, bias=bias)
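
A similar sketch for the ViTamin MLP change (again an assumption-laden example, not from the commit): GeGluMlp now accepts an optional norm_layer and falls back to the previous hardcoded 'layernorm' (eps=1e-6) when none is given.

    import torch
    from timm.models.vitamin import GeGluMlp

    x = torch.randn(2, 196, 256)
    mlp = GeGluMlp(256, 512)                             # default: LayerNorm(eps=1e-6), as before
    mlp_rms = GeGluMlp(256, 512, norm_layer='rmsnorm')   # assumes 'rmsnorm' is registered with get_norm_layer
    print(mlp(x).shape, mlp_rms(x).shape)                # both torch.Size([2, 196, 256])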