mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Torchscript fixes/hacks for rms_norm, refactor ParallelScalingBlock with manual combination of input projections, closer paper match
This commit is contained in:
parent
122621daef
commit
f77c04ff36
@ -90,7 +90,15 @@ def rms_norm(
|
||||
weight: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-5,
|
||||
):
|
||||
dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
|
||||
norm_ndim = len(normalized_shape)
|
||||
if torch.jit.is_scripting():
|
||||
# ndim = len(x.shape)
|
||||
# dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x
|
||||
# NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
|
||||
assert norm_ndim == 1
|
||||
v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True
|
||||
else:
|
||||
dims = tuple(range(-1, -norm_ndim - 1, -1))
|
||||
v = torch.var(x, dim=dims, keepdim=True)
|
||||
x = x * torch.rsqrt(v + eps)
|
||||
if weight is not None:
|
||||
@ -104,10 +112,15 @@ def fast_rms_norm(
|
||||
weight: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-5,
|
||||
) -> torch.Tensor:
|
||||
if torch.jit.is_scripting() or not has_apex_rmsnorm:
|
||||
if torch.jit.is_scripting():
|
||||
# this must be by itself, cannot merge with has_apex_rmsnorm
|
||||
return rms_norm(x, normalized_shape, weight, eps)
|
||||
|
||||
if has_apex_rmsnorm:
|
||||
if weight is None:
|
||||
return fused_rms_norm(x, normalized_shape, eps)
|
||||
else:
|
||||
return fused_rms_norm_affine(x, weight, normalized_shape, eps)
|
||||
|
||||
# fallback
|
||||
return rms_norm(x, normalized_shape, weight, eps)
|
||||
|
@ -122,6 +122,7 @@ class LayerNormExp2d(nn.LayerNorm):
|
||||
class RmsNorm(nn.Module):
|
||||
""" RmsNorm w/ fast (apex) norm if available
|
||||
"""
|
||||
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
|
||||
normalized_shape: Tuple[int, ...]
|
||||
eps: float
|
||||
elementwise_affine: bool
|
||||
|
@ -217,6 +217,8 @@ class ParallelScalingBlock(nn.Module):
|
||||
Based on:
|
||||
'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
|
||||
"""
|
||||
fast_attn: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
@ -232,33 +234,76 @@ class ParallelScalingBlock(nn.Module):
|
||||
norm_layer=nn.LayerNorm
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
|
||||
mlp_hidden_dim = int(mlp_ratio * dim)
|
||||
in_proj_out_dim = mlp_hidden_dim + 3 * dim
|
||||
out_proj_in_dim = mlp_hidden_dim + dim
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.in_norm = norm_layer(dim)
|
||||
self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias)
|
||||
self.in_split = [mlp_hidden_dim] + [dim] * 3
|
||||
if qkv_bias:
|
||||
self.register_buffer('qkv_bias', None)
|
||||
self.register_parameter('mlp_bias', None)
|
||||
else:
|
||||
self.register_buffer('qkv_bias', torch.zeros(3 * dim), persistent=False)
|
||||
self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim))
|
||||
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.attn_out_proj = nn.Linear(dim, dim)
|
||||
|
||||
self.mlp_drop = nn.Dropout(drop)
|
||||
self.mlp_act = act_layer()
|
||||
self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
|
||||
|
||||
self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
y1 = self.drop_path1(self.ls1(self.attn(self.norm1(x))))
|
||||
y2 = self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
||||
x = x + y1 + y2
|
||||
B, N, C = x.shape
|
||||
|
||||
# Combined MLP fc1 & qkv projections
|
||||
y = self.in_norm(x)
|
||||
if self.mlp_bias is not None:
|
||||
# Concat constant zero-bias for qkv w/ trainable mlp_bias.
|
||||
# Appears faster than adding to x_mlp separately
|
||||
y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)))
|
||||
else:
|
||||
y = self.in_proj(y)
|
||||
x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
|
||||
|
||||
# Dot product attention w/ qk norm
|
||||
q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
|
||||
k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
|
||||
v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
if self.fast_attn:
|
||||
x_attn = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x_attn = attn @ v
|
||||
x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
|
||||
x_attn = self.attn_out_proj(x_attn)
|
||||
|
||||
# MLP activation, dropout, fc2
|
||||
x_mlp = self.mlp_act(x_mlp)
|
||||
x_mlp = self.mlp_drop(x_mlp)
|
||||
x_mlp = self.mlp_out_proj(x_mlp)
|
||||
|
||||
# Add residual w/ drop path & layer scale applied
|
||||
y = self.drop_path(self.ls(x_attn + x_mlp))
|
||||
x = x + y
|
||||
return x
|
||||
|
||||
|
||||
@ -1249,6 +1294,7 @@ default_cfgs = generate_default_cfgs({
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
|
||||
|
||||
'vit_base_patch16_xp_224.untrained': _cfg(url=''),
|
||||
'vit_large_patch14_xp_224.untrained': _cfg(url=''),
|
||||
'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
|
||||
})
|
||||
@ -1750,6 +1796,19 @@ def flexivit_large(pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch16_xp_224(pretrained=False, **kwargs):
|
||||
""" ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, no_embed_class=True,
|
||||
norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_base_patch16_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch14_xp_224(pretrained=False, **kwargs):
|
||||
""" ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
|
||||
|
Loading…
x
Reference in New Issue
Block a user