renaming models and update checkpoint to vit-only

This commit is contained in:
berniebear 2025-04-25 21:26:34 +00:00
parent 3af564f2fb
commit b82f9e869c

View File

@ -14,10 +14,27 @@ from torch.amp import autocast
from torch.utils.checkpoint import checkpoint
### Import timm layers
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
get_act_layer, get_norm_layer, LayerType, LayerScale
#from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
from timm.layers import (
PatchEmbed,
Mlp,
DropPath,
AttentionPoolLatent,
RmsNorm,
PatchDropout,
SwiGLUPacked,
SwiGLU,
trunc_normal_,
lecun_normal_,
resample_patch_embed,
resample_abs_pos_embed,
use_fused_attn,
get_act_layer,
get_norm_layer,
LayerType,
LayerScale,
)
# from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from ._builder import build_model_with_cfg
@ -28,20 +45,22 @@ from ._registry import generate_default_cfgs, register_model, register_model_dep
__all__ = ['PE']
####### PE's Rope ########
######## PE's Rope ########
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
@autocast("cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
dtype = t.dtype
@ -73,9 +92,7 @@ class RotaryEmbedding(Module):
self,
dim,
custom_freqs: Optional[Tensor] = None,
freqs_for: Union[
Literal["lang"], Literal["pixel"], Literal["constant"]
] = "lang",
freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang",
theta=10000,
max_freq=10,
num_freqs=1,
@ -99,9 +116,7 @@ class RotaryEmbedding(Module):
if exists(custom_freqs):
freqs = custom_freqs
elif freqs_for == "lang":
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
elif freqs_for == "pixel":
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
@ -154,9 +169,7 @@ class RotaryEmbedding(Module):
self.register_buffer(key, value, persistent=False)
def get_seq_pos(self, seq_len, device, dtype, offset=0):
return (
torch.arange(seq_len, device=device, dtype=dtype) + offset
) / self.interpolate_factor
return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
seq_dim = default(seq_dim, self.default_seq_dim)
@ -184,9 +197,7 @@ class RotaryEmbedding(Module):
q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
assert q_len <= k_len
rotated_q = self.rotate_queries_or_keys(
q, seq_dim=seq_dim, offset=k_len - q_len + offset
)
rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=k_len - q_len + offset)
rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)
rotated_q = rotated_q.type(q.dtype)
@ -222,11 +233,7 @@ class RotaryEmbedding(Module):
should_cache = self.cache_if_possible and exists(seq_len)
if (
should_cache
and exists(self.cached_scales)
and (seq_len + offset) <= self.cached_scales.shape[0]
):
if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]:
return self.cached_scales[offset : (offset + seq_len)]
scale = 1.0
@ -264,17 +271,10 @@ class RotaryEmbedding(Module):
@autocast("cuda", enabled=False)
def forward(self, t: Tensor, seq_len=None, offset=0):
should_cache = (
self.cache_if_possible
and not self.learned_freq
and exists(seq_len)
and self.freqs_for != "pixel"
self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel"
)
if (
should_cache
and exists(self.cached_freqs)
and (offset + seq_len) <= self.cached_freqs.shape[0]
):
if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs.shape[0]:
return self.cached_freqs[offset : (offset + seq_len)].detach()
freqs = self.freqs
@ -289,7 +289,7 @@ class RotaryEmbedding(Module):
class Rope2D:
""" Helper class to apply RoPE2D as well as interpolate on the fly. """
"""Helper class to apply RoPE2D as well as interpolate on the fly."""
def __init__(self, dim, use_cls_token=False):
self.dim = dim
@ -313,13 +313,11 @@ class Rope2D:
grid_y_range = torch.arange(grid_h, device=device)
grid_x_range = torch.arange(grid_w, device=device)
freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
if self.use_cls_token:
freq = torch.cat(
[freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0
)
freq = torch.cat([freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0)
self.freq = freq[None, ...]
@ -332,8 +330,8 @@ class Rope2D:
return q, k
####### PE's Modules ########
######## PE Modules ########
class AttentionPooling(nn.Module):
def __init__(
self,
@ -349,14 +347,10 @@ class AttentionPooling(nn.Module):
self.embed_dim = embed_dim
self.num_heads = num_heads
assert (
self.embed_dim % num_heads == 0
), "embed_dim must be divisible by num_heads"
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
self.attn = nn.MultiheadAttention(
self.embed_dim, self.num_heads, batch_first=True
)
self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True)
self.layernorm = norm_layer(embed_dim)
self.mlp_width = int(embed_dim * mlp_ratio)
@ -396,9 +390,7 @@ class SelfAttention(nn.Module):
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
# To make this compatibile with nn.MultiHeadAttention
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
@ -418,13 +410,7 @@ class SelfAttention(nn.Module):
proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)
# reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
proj = (
proj.unflatten(-1, (3, embed_dim))
.unsqueeze(0)
.transpose(0, -2)
.squeeze(-2)
.contiguous()
)
proj = proj.unflatten(-1, (3, embed_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
q, k, v = proj[0], proj[1], proj[2]
# Use "q_" so that we don't accidentally quit in pdb :)
@ -435,9 +421,7 @@ class SelfAttention(nn.Module):
if self.rope:
q, k = self.rope(q, k)
attn = F.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
)
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
attn = rearrange(attn, "b h s d -> b s (h d)")
return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
@ -462,16 +446,8 @@ class ResidualAttentionBlock(nn.Module):
else:
self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
self.ls_1 = (
LayerScale(d_model, ls_init_value)
if ls_init_value is not None
else nn.Identity()
)
self.ls_2 = (
LayerScale(d_model, ls_init_value)
if ls_init_value is not None
else nn.Identity()
)
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ln_1 = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
@ -511,9 +487,7 @@ class ResidualAttentionBlock(nn.Module):
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
):
x = x + self.drop_path1(
self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))
)
x = x + self.drop_path1(self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask)))
x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
return x
@ -558,9 +532,9 @@ class Transformer(nn.Module):
@torch.jit.ignore
def truncate(self, layer_idx: int):
""" Delete layers so the last layer is the given layer index. """
"""Delete layers so the last layer is the given layer index."""
self.layers = ((self.layers + layer_idx) % self.layers) + 1
self.resblocks = nn.ModuleList(self.resblocks[:self.layers])
self.resblocks = nn.ModuleList(self.resblocks[: self.layers])
def forward(
self,
@ -576,7 +550,7 @@ class Transformer(nn.Module):
x = checkpoint(r, x, None, None, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
if i == stop_idx:
break
@ -604,7 +578,7 @@ class PE(nn.Module):
output_dim: Optional[int] = 1280,
attn_pooler_heads: int = 8,
pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
num_classes: int = 1000, # no use for now
num_classes: int = 1000, # no use for now
in_chans: int = 3,
):
super().__init__()
@ -666,12 +640,11 @@ class PE(nn.Module):
self.init_tensors()
def init_tensors(self):
def init_submodule_tensors(module):
for name, child in module.named_children():
if hasattr(child, "init_tensors"):
#logger.debug(f"Initializing tensors for submodule: {name}")
# logger.debug(f"Initializing tensors for submodule: {name}")
child.init_tensors()
init_submodule_tensors(child)
@ -687,19 +660,14 @@ class PE(nn.Module):
if self.use_abs_posemb:
self.posemb_grid_size = self.image_size // self.patch_size
self.positional_embedding = nn.Parameter(
init_scale
* torch.randn(
int(self.use_cls_token) + self.posemb_grid_size**2, self.width
)
init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width)
)
if self.proj_dim is not None:
self.proj = nn.Parameter(
init_scale * torch.randn(self.width, self.proj_dim)
)
self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim))
def truncate(self, layer_idx: int):
""" Delete layers so the last layer is the given layer index. """
"""Delete layers so the last layer is the given layer index."""
self.transformer.truncate(layer_idx)
self.layers = self.transformer.layers
@ -717,13 +685,9 @@ class PE(nn.Module):
cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
pos_embed = (
pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
.permute(0, 3, 1, 2)
.contiguous()
)
pos_embed = F.interpolate(
pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1).permute(0, 3, 1, 2).contiguous()
)
pos_embed = F.interpolate(pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False)
pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()
if self.use_cls_token:
@ -743,13 +707,7 @@ class PE(nn.Module):
else:
raise NotImplementedError
def forward_features(
self,
x: torch.Tensor,
norm: bool = False,
layer_idx: int = -1,
strip_cls_token: bool = False
):
def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False):
batch, _, h, w = x.shape
grid_h, grid_w = h // self.patch_size, w // self.patch_size
@ -790,14 +748,8 @@ class PE(nn.Module):
def checkpoint_filter_fn(
state_dict: Dict[str, torch.Tensor],
model: PE,
adapt_layer_scale: bool = False,
interpolation: str = 'bicubic',
antialias: bool = True,
state_dict: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
""" convert patch embedding weight from manual patchify + linear proj to conv"""
import re
state_dict = state_dict.get('model', state_dict)
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
if any(k.startswith("visual.") for k in state_dict):
@ -805,42 +757,33 @@ def checkpoint_filter_fn(
return state_dict
######## PE Config ########
def _cfg(url='', **kwargs):
return {
'license': 'apache-2.0',
'num_classes': 0,
'interpolation': 'bilinear',
'fixed_input_size': True,
'mean': IMAGENET_INCEPTION_MEAN,
'mean': IMAGENET_INCEPTION_MEAN,
'std': IMAGENET_INCEPTION_STD,
**kwargs
**kwargs,
}
default_cfgs = generate_default_cfgs({
'pe_core_b16_224': _cfg(
hf_hub_id='timm/',
input_size=(3, 224, 224)),
'pe_core_l14_336': _cfg(
hf_hub_id='timm/',
input_size=(3, 336, 336)),
'pe_core_G14_448': _cfg(
hf_hub_id='timm/',
input_size=(3, 448, 448)),
'pe_lang_l14_448': _cfg(
hf_hub_id='timm/',
input_size=(3, 448, 448)),
'pe_lang_G14_448': _cfg(
hf_hub_id='timm/',
input_size=(3, 448, 448)),
'pe_spatial_G14_448': _cfg(
hf_hub_id='timm/',
input_size=(3, 448, 448)),
})
default_cfgs = generate_default_cfgs(
{
'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='timm/', input_size=(3, 224, 224)),
'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)),
'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
'vit_pe_lang_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
'vit_pe_spatial_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
}
)
def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
out_indices = kwargs.pop('out_indices', 3)
return build_model_with_cfg(
PE,
variant,
@ -851,99 +794,104 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
**kwargs,
)
@register_model
def pe_core_b16_224(pretrained=False, **kwargs):
model_args = dict(
image_size = 224,
patch_size = 16,
width = 768,
layers = 12,
heads = 12,
mlp_ratio = 4.0,
output_dim = 1024,
use_cls_token = True,
pool_type = 'attn',
)
return _create_pe('pe_core_b16_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def pe_core_l14_336(pretrained=False, **kwargs):
def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
model_args = dict(
image_size = 336,
patch_size = 14,
width = 1024,
layers = 24,
heads = 16,
mlp_ratio = 4.0,
output_dim = 1024,
use_cls_token = True,
pool_type = 'attn',
image_size=224,
patch_size=16,
width=768,
layers=12,
heads=12,
mlp_ratio=4.0,
output_dim=1024,
use_cls_token=True,
pool_type='attn',
)
return _create_pe('pe_core_l14_336', pretrained=pretrained, **dict(model_args, **kwargs))
return _create_pe('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def pe_core_G14_448(pretrained=False, **kwargs):
def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
model_args = dict(
image_size = 448,
patch_size = 14,
width = 1536,
layers = 50,
heads = 16,
mlp_ratio = 8960 / 1536,
output_dim = 1280,
use_cls_token = False,
pool_type = 'attn',
image_size=336,
patch_size=14,
width=1024,
layers=24,
heads=16,
mlp_ratio=4.0,
output_dim=1024,
use_cls_token=True,
pool_type='attn',
)
return _create_pe('pe_core_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
return _create_pe('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def pe_lang_G14_448(pretrained=False, **kwargs):
def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
model_args = dict(
image_size = 448,
patch_size = 14,
width = 1536,
layers = 47,
heads = 16,
mlp_ratio = 8960 / 1536,
output_dim = None,
use_cls_token = False,
use_ln_post = False,
pool_type = 'none',
ls_init_value = 0.1,
image_size=448,
patch_size=14,
width=1536,
layers=50,
heads=16,
mlp_ratio=8960 / 1536,
output_dim=1280,
use_cls_token=False,
pool_type='attn',
)
return _create_pe('pe_lang_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
return _create_pe('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def pe_lang_l14_448(pretrained=False, **kwargs):
def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
model_args = dict(
image_size = 448,
patch_size = 14,
width = 1024,
layers = 23,
heads = 16,
mlp_ratio = 4.0,
output_dim = None,
use_cls_token = True,
use_ln_post = False,
pool_type = 'none',
ls_init_value = 0.1,
image_size=448,
patch_size=14,
width=1024,
layers=23,
heads=16,
mlp_ratio=4.0,
output_dim=None,
use_cls_token=True,
use_ln_post=False,
pool_type='none',
ls_init_value=0.1,
)
return _create_pe('pe_lang_l14_448', pretrained=pretrained, **dict(model_args, **kwargs))
return _create_pe('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def pe_spatial_G14_448(pretrained=False, **kwargs):
def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
model_args = dict(
image_size = 448,
patch_size = 14,
width = 1536,
layers = 50,
heads = 16,
mlp_ratio = 8960 / 1536,
output_dim = None,
use_cls_token = False,
use_ln_post = False,
pool_type = 'none',
ls_init_value = 0.1,
image_size=448,
patch_size=14,
width=1536,
layers=47,
heads=16,
mlp_ratio=8960 / 1536,
output_dim=None,
use_cls_token=False,
use_ln_post=False,
pool_type='none',
ls_init_value=0.1,
)
return _create_pe('pe_spatial_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
return _create_pe('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
model_args = dict(
image_size=448,
patch_size=14,
width=1536,
layers=50,
heads=16,
mlp_ratio=8960 / 1536,
output_dim=None,
use_cls_token=False,
use_ln_post=False,
pool_type='none',
ls_init_value=0.1,
)
return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))