Fix parallel blocks missing scale args and vitamin MLP

Ross Wightman 2025-05-29 11:34:19 -07:00
parent 2ca94a6ce4
commit 3a3d98bc38
2 changed files with 13 additions and 5 deletions


@@ -64,7 +64,7 @@ class Attention(nn.Module):
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
- scale_attn_norm: bool = False,
+ scale_norm: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
@@ -80,7 +80,7 @@ class Attention(nn.Module):
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.norm = norm_layer(dim) if scale_attn_norm else nn.Identity()
+ self.norm = norm_layer(dim) if scale_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
@@ -151,7 +151,7 @@ class Block(nn.Module):
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
- scale_attn_norm=scale_attn_norm,
+ scale_norm=scale_attn_norm,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
@@ -205,7 +205,7 @@ class ResPostBlock(nn.Module):
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
- scale_attn_norm=scale_attn_norm,
+ scale_norm=scale_attn_norm,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
@@ -253,6 +253,8 @@ class ParallelScalingBlock(nn.Module):
mlp_ratio: float = 4.,
qkv_bias: bool = False,
qk_norm: bool = False,
+ scale_attn_norm: bool = False,
+ scale_mlp_norm: bool = False,
proj_bias: bool = True,
proj_drop: float = 0.,
attn_drop: float = 0.,
@@ -264,6 +266,7 @@
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+ assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
@@ -348,6 +351,8 @@ class ParallelThingsBlock(nn.Module):
mlp_ratio: float = 4.,
qkv_bias: bool = False,
qk_norm: bool = False,
+ scale_attn_norm: bool = False,
+ scale_mlp_norm: bool = False,
proj_bias: bool = True,
init_values: Optional[float] = None,
proj_drop: float = 0.,
@@ -369,6 +374,7 @@ class ParallelThingsBlock(nn.Module):
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
+ scale_norm=scale_attn_norm,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
@@ -383,6 +389,7 @@ class ParallelThingsBlock(nn.Module):
dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
+ norm_layer=norm_layer if scale_mlp_norm else None,
bias=proj_bias,
drop=proj_drop,
)),
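
For context, a minimal usage sketch of the behavior after this change. It assumes the parallel blocks live in timm.models.vision_transformer (the file names are not shown in this view) and uses illustrative dim/num_heads values: ParallelThingsBlock now forwards the new scale_attn_norm/scale_mlp_norm flags to its attention and MLP sub-modules, while ParallelScalingBlock accepts the keywords only to reject them via the new assert.

# Minimal sketch; assumptions: import path timm.models.vision_transformer,
# (batch, tokens, dim) tensor layout, illustrative sizes.
import torch
from timm.models.vision_transformer import ParallelScalingBlock, ParallelThingsBlock

# ParallelThingsBlock now accepts the scale-norm flags and forwards them:
# scale_attn_norm -> Attention(scale_norm=...), scale_mlp_norm -> the MLP's norm_layer.
block = ParallelThingsBlock(dim=384, num_heads=6, scale_attn_norm=True, scale_mlp_norm=True)
x = torch.randn(2, 197, 384)
print(block(x).shape)  # the block preserves the token shape: (2, 197, 384)

# ParallelScalingBlock takes the same keywords but rejects them up front.
try:
    ParallelScalingBlock(dim=384, num_heads=6, scale_attn_norm=True)
except AssertionError as err:
    print(err)  # 'Scale norms not supported'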


@@ -259,11 +259,12 @@ class GeGluMlp(nn.Module):
in_features,
hidden_features,
act_layer = 'gelu',
+ norm_layer = None,
bias = True,
drop = 0.0,
):
super().__init__()
- norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
+ norm_layer = partial(get_norm_layer(norm_layer or 'layernorm'), eps=1e-6)
self.norm = norm_layer(in_features)
self.w0 = nn.Linear(in_features, hidden_features, bias=bias)
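
And a corresponding sketch for the vitamin MLP side of the fix, assuming GeGluMlp is the MLP defined in timm.models.vitamin (again, the file name is not shown in this view): the constructor now tolerates a norm_layer keyword, as the ViT blocks pass one, and falls back to 'layernorm' when it is left as None.

# Minimal sketch; assumptions: import path timm.models.vitamin, illustrative sizes.
import torch
from timm.models.vitamin import GeGluMlp

mlp_default = GeGluMlp(in_features=384, hidden_features=1536)  # norm_layer=None falls back to 'layernorm'
mlp_explicit = GeGluMlp(in_features=384, hidden_features=1536, norm_layer='layernorm')

x = torch.randn(2, 196, 384)
# Both should behave the same, since None resolves to the layernorm default.
print(mlp_default(x).shape, mlp_explicit(x).shape)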