Rename models and update checkpoint loading to ViT-only

berniebear 2025-04-25 21:26:34 +00:00
parent 3af564f2fb
commit b82f9e869c
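Note: this commit renames the PE entrypoints to ViT-style names (e.g. pe_core_b16_224 becomes vit_pe_core_base_patch16_224). A minimal usage sketch, assuming the renamed entrypoints in this diff are registered and that weights are published under the same names on the hub:

import timm

# Enumerate the renamed Perception Encoder variants registered by this file.
print(timm.list_models('vit_pe_*'))

# Old entrypoint pe_core_b16_224 is now vit_pe_core_base_patch16_224.
model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False)
model.eval()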


@@ -14,9 +14,26 @@ from torch.amp import autocast
 from torch.utils.checkpoint import checkpoint
 ### Import timm layers
-from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
-    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType, LayerScale
+from timm.layers import (
+    PatchEmbed,
+    Mlp,
+    DropPath,
+    AttentionPoolLatent,
+    RmsNorm,
+    PatchDropout,
+    SwiGLUPacked,
+    SwiGLU,
+    trunc_normal_,
+    lecun_normal_,
+    resample_patch_embed,
+    resample_abs_pos_embed,
+    use_fused_attn,
+    get_act_layer,
+    get_norm_layer,
+    LayerType,
+    LayerScale,
+)
 # from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@@ -28,20 +45,22 @@ from ._registry import generate_default_cfgs, register_model, register_model_dep
 __all__ = ['PE']
-####### PE's Rope ########
+######## PE's Rope ########
 def exists(val):
     return val is not None
 def default(val, d):
     return val if exists(val) else d
 def rotate_half(x):
     x = rearrange(x, "... (d r) -> ... d r", r=2)
     x1, x2 = x.unbind(dim=-1)
     x = torch.stack((-x2, x1), dim=-1)
     return rearrange(x, "... d r -> ... (d r)")
 @autocast("cuda", enabled=False)
 def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
     dtype = t.dtype
@@ -73,9 +92,7 @@ class RotaryEmbedding(Module):
         self,
         dim,
         custom_freqs: Optional[Tensor] = None,
-        freqs_for: Union[
-            Literal["lang"], Literal["pixel"], Literal["constant"]
-        ] = "lang",
+        freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang",
         theta=10000,
         max_freq=10,
         num_freqs=1,
@@ -99,9 +116,7 @@ class RotaryEmbedding(Module):
         if exists(custom_freqs):
             freqs = custom_freqs
         elif freqs_for == "lang":
-            freqs = 1.0 / (
-                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
-            )
+            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
         elif freqs_for == "pixel":
             freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
         elif freqs_for == "constant":
@@ -154,9 +169,7 @@ class RotaryEmbedding(Module):
         self.register_buffer(key, value, persistent=False)
     def get_seq_pos(self, seq_len, device, dtype, offset=0):
-        return (
-            torch.arange(seq_len, device=device, dtype=dtype) + offset
-        ) / self.interpolate_factor
+        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
     def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
         seq_dim = default(seq_dim, self.default_seq_dim)
@@ -184,9 +197,7 @@ class RotaryEmbedding(Module):
         q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
         assert q_len <= k_len
-        rotated_q = self.rotate_queries_or_keys(
-            q, seq_dim=seq_dim, offset=k_len - q_len + offset
-        )
+        rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=k_len - q_len + offset)
         rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)
         rotated_q = rotated_q.type(q.dtype)
@@ -222,11 +233,7 @@ class RotaryEmbedding(Module):
         should_cache = self.cache_if_possible and exists(seq_len)
-        if (
-            should_cache
-            and exists(self.cached_scales)
-            and (seq_len + offset) <= self.cached_scales.shape[0]
-        ):
+        if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]:
             return self.cached_scales[offset : (offset + seq_len)]
         scale = 1.0
@@ -264,17 +271,10 @@ class RotaryEmbedding(Module):
     @autocast("cuda", enabled=False)
     def forward(self, t: Tensor, seq_len=None, offset=0):
         should_cache = (
-            self.cache_if_possible
-            and not self.learned_freq
-            and exists(seq_len)
-            and self.freqs_for != "pixel"
+            self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel"
         )
-        if (
-            should_cache
-            and exists(self.cached_freqs)
-            and (offset + seq_len) <= self.cached_freqs.shape[0]
-        ):
+        if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs.shape[0]:
             return self.cached_freqs[offset : (offset + seq_len)].detach()
         freqs = self.freqs
@@ -317,9 +317,7 @@ class Rope2D:
         freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
         if self.use_cls_token:
-            freq = torch.cat(
-                [freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0
-            )
+            freq = torch.cat([freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0)
         self.freq = freq[None, ...]
@@ -332,8 +330,8 @@ class Rope2D:
         return q, k
-####### PE's Modules ########
+######## PE Modules ########
 class AttentionPooling(nn.Module):
     def __init__(
         self,
@@ -349,14 +347,10 @@ class AttentionPooling(nn.Module):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
-        assert (
-            self.embed_dim % num_heads == 0
-        ), "embed_dim must be divisible by num_heads"
+        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
         self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
-        self.attn = nn.MultiheadAttention(
-            self.embed_dim, self.num_heads, batch_first=True
-        )
+        self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True)
         self.layernorm = norm_layer(embed_dim)
         self.mlp_width = int(embed_dim * mlp_ratio)
@@ -396,9 +390,7 @@ class SelfAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
         # To make this compatibile with nn.MultiHeadAttention
         self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
@@ -418,13 +410,7 @@ class SelfAttention(nn.Module):
         proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)
         # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
-        proj = (
-            proj.unflatten(-1, (3, embed_dim))
-            .unsqueeze(0)
-            .transpose(0, -2)
-            .squeeze(-2)
-            .contiguous()
-        )
+        proj = proj.unflatten(-1, (3, embed_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
         q, k, v = proj[0], proj[1], proj[2]
         # Use "q_" so that we don't accidentally quit in pdb :)
@@ -435,9 +421,7 @@ class SelfAttention(nn.Module):
         if self.rope:
             q, k = self.rope(q, k)
-        attn = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
-        )
+        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
         attn = rearrange(attn, "b h s d -> b s (h d)")
         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
@@ -462,16 +446,8 @@ class ResidualAttentionBlock(nn.Module):
         else:
             self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
-        self.ls_1 = (
-            LayerScale(d_model, ls_init_value)
-            if ls_init_value is not None
-            else nn.Identity()
-        )
-        self.ls_2 = (
-            LayerScale(d_model, ls_init_value)
-            if ls_init_value is not None
-            else nn.Identity()
-        )
+        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
         self.ln_1 = norm_layer(d_model)
         self.ln_2 = norm_layer(d_model)
@@ -511,9 +487,7 @@ class ResidualAttentionBlock(nn.Module):
         x: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
     ):
-        x = x + self.drop_path1(
-            self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))
-        )
+        x = x + self.drop_path1(self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask)))
         x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
         return x
@@ -666,7 +640,6 @@ class PE(nn.Module):
         self.init_tensors()
     def init_tensors(self):
         def init_submodule_tensors(module):
            for name, child in module.named_children():
@@ -687,16 +660,11 @@ class PE(nn.Module):
         if self.use_abs_posemb:
             self.posemb_grid_size = self.image_size // self.patch_size
             self.positional_embedding = nn.Parameter(
-                init_scale
-                * torch.randn(
-                    int(self.use_cls_token) + self.posemb_grid_size**2, self.width
-                )
+                init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width)
             )
         if self.proj_dim is not None:
-            self.proj = nn.Parameter(
-                init_scale * torch.randn(self.width, self.proj_dim)
-            )
+            self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim))
     def truncate(self, layer_idx: int):
         """Delete layers so the last layer is the given layer index."""
@@ -717,13 +685,9 @@ class PE(nn.Module):
             cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
-        pos_embed = (
-            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
-            .permute(0, 3, 1, 2)
-            .contiguous()
-        )
-        pos_embed = F.interpolate(
-            pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
-        )
+        pos_embed = (
+            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1).permute(0, 3, 1, 2).contiguous()
+        )
+        pos_embed = F.interpolate(pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False)
         pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()
         if self.use_cls_token:
@@ -743,13 +707,7 @@ class PE(nn.Module):
         else:
             raise NotImplementedError
-    def forward_features(
-        self,
-        x: torch.Tensor,
-        norm: bool = False,
-        layer_idx: int = -1,
-        strip_cls_token: bool = False
-    ):
+    def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False):
         batch, _, h, w = x.shape
         grid_h, grid_w = h // self.patch_size, w // self.patch_size
@@ -791,13 +749,7 @@ class PE(nn.Module):
 def checkpoint_filter_fn(
         state_dict: Dict[str, torch.Tensor],
-        model: PE,
-        adapt_layer_scale: bool = False,
-        interpolation: str = 'bicubic',
-        antialias: bool = True,
 ) -> Dict[str, torch.Tensor]:
-    """ convert patch embedding weight from manual patchify + linear proj to conv"""
-    import re
     state_dict = state_dict.get('model', state_dict)
     state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
     if any(k.startswith("visual.") for k in state_dict):
@@ -805,6 +757,7 @@ def checkpoint_filter_fn(
     return state_dict
+######## PE Config ########
 def _cfg(url='', **kwargs):
     return {
         'license': 'apache-2.0',
@@ -813,34 +766,24 @@ def _cfg(url='', **kwargs):
         'fixed_input_size': True,
         'mean': IMAGENET_INCEPTION_MEAN,
         'std': IMAGENET_INCEPTION_STD,
-        **kwargs
+        **kwargs,
     }
-default_cfgs = generate_default_cfgs({
-    'pe_core_b16_224': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 224, 224)),
-    'pe_core_l14_336': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 336, 336)),
-    'pe_core_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_lang_l14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_lang_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_spatial_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-})
+default_cfgs = generate_default_cfgs(
+    {
+        'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='timm/', input_size=(3, 224, 224)),
+        'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)),
+        'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_lang_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_spatial_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+    }
+)
 def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
     out_indices = kwargs.pop('out_indices', 3)
     return build_model_with_cfg(
         PE,
         variant,
@@ -851,8 +794,9 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
         **kwargs,
     )
 @register_model
-def pe_core_b16_224(pretrained=False, **kwargs):
+def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
     model_args = dict(
         image_size=224,
         patch_size=16,
@@ -864,10 +808,11 @@ def pe_core_b16_224(pretrained=False, **kwargs):
         use_cls_token=True,
         pool_type='attn',
     )
-    return _create_pe('pe_core_b16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
 @register_model
-def pe_core_l14_336(pretrained=False, **kwargs):
+def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
     model_args = dict(
         image_size=336,
         patch_size=14,
@@ -879,11 +824,11 @@ def pe_core_l14_336(pretrained=False, **kwargs):
         use_cls_token=True,
         pool_type='attn',
     )
-    return _create_pe('pe_core_l14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
 @register_model
-def pe_core_G14_448(pretrained=False, **kwargs):
+def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
         image_size=448,
         patch_size=14,
@@ -895,27 +840,11 @@ def pe_core_G14_448(pretrained=False, **kwargs):
         use_cls_token=False,
         pool_type='attn',
     )
-    return _create_pe('pe_core_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
 @register_model
-def pe_lang_G14_448(pretrained=False, **kwargs):
-    model_args = dict(
-        image_size = 448,
-        patch_size = 14,
-        width = 1536,
-        layers = 47,
-        heads = 16,
-        mlp_ratio = 8960 / 1536,
-        output_dim = None,
-        use_cls_token = False,
-        use_ln_post = False,
-        pool_type = 'none',
-        ls_init_value = 0.1,
-    )
-    return _create_pe('pe_lang_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
-@register_model
-def pe_lang_l14_448(pretrained=False, **kwargs):
+def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
         image_size=448,
         patch_size=14,
@@ -929,10 +858,29 @@ def pe_lang_l14_448(pretrained=False, **kwargs):
         pool_type='none',
         ls_init_value=0.1,
     )
-    return _create_pe('pe_lang_l14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
 @register_model
-def pe_spatial_G14_448(pretrained=False, **kwargs):
+def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
+    model_args = dict(
+        image_size=448,
+        patch_size=14,
+        width=1536,
+        layers=47,
+        heads=16,
+        mlp_ratio=8960 / 1536,
+        output_dim=None,
+        use_cls_token=False,
+        use_ln_post=False,
+        pool_type='none',
+        ls_init_value=0.1,
+    )
+    return _create_pe('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+@register_model
+def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
         image_size=448,
         patch_size=14,
@@ -946,4 +894,4 @@ def pe_spatial_G14_448(pretrained=False, **kwargs):
         pool_type='none',
         ls_init_value=0.1,
     )
-    return _create_pe('pe_spatial_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
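The "ViT-only" checkpoint handling above only shows the head of the visual.-prefix branch in checkpoint_filter_fn. Below is a hedged sketch of the filtering that is visible in this diff; the body of the visual. branch is not shown in the hunk, so it is assumed here to keep only the vision-tower keys and strip their prefix.

def filter_pe_checkpoint(state_dict):
    # Shown in the diff: unwrap checkpoints saved as {'model': ...} and drop DDP 'module.' prefixes.
    state_dict = state_dict.get('model', state_dict)
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    if any(k.startswith("visual.") for k in state_dict):
        # Assumption (branch body not in this diff): keep only the vision tower, without the 'visual.' prefix.
        state_dict = {k[len("visual."):]: v for k, v in state_dict.items() if k.startswith("visual.")}
    return state_dict

# Possible use, assuming a full PE checkpoint on disk:
# model.load_state_dict(filter_pe_checkpoint(torch.load(ckpt_path, map_location='cpu')))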